diff --git a/CHANGELOG.md b/CHANGELOG.md index 5f16efe76..45e4a01f3 100755 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -23,6 +23,7 @@ - #1396 Create tables from other RDBMS - #1427 Support for CONCAT alias operator - #1424 Add get physical plan with explain +- #1472 Implement predicate pushdown for data providers ## Improvements - #1325 Refactored CacheMachine.h and CacheMachine.cpp diff --git a/README.md b/README.md index a0e18383c..c9b75cea1 100644 --- a/README.md +++ b/README.md @@ -70,7 +70,7 @@ bc.sql('SELECT passenger_count, trip_distance FROM taxi LIMIT 2') ## Documentation You can find our full documentation at [docs.blazingdb.com](https://docs.blazingdb.com/docs). -# Prerequisites +# Prerequisites * [Anaconda or Miniconda](https://docs.conda.io/projects/conda/en/latest/user-guide/install/linux.html) installed * OS Support * Ubuntu 16.04/18.04 LTS @@ -96,7 +96,7 @@ Where $CUDA_VERSION is 10.1, 10.2 or 11.0 and $PYTHON_VERSION is 3.7 or 3.8 *For example for CUDA 10.1 and Python 3.7:* ```bash conda install -c blazingsql -c rapidsai -c nvidia -c conda-forge -c defaults blazingsql python=3.7 cudatoolkit=10.1 -``` +``` ## Nightly Version ```bash diff --git a/engine/src/cython/engine.cpp b/engine/src/cython/engine.cpp index 3a966cba1..6c6e6aeb4 100644 --- a/engine/src/cython/engine.cpp +++ b/engine/src/cython/engine.cpp @@ -25,6 +25,11 @@ #include "../io/data_provider/sql/MySQLDataProvider.h" #endif +#ifdef POSTGRESQL_SUPPORT +#include "../io/data_parser/sql/PostgreSQLParser.h" +#include "../io/data_provider/sql/PostgreSQLDataProvider.h" +#endif + #ifdef SQLITE_SUPPORT #include "../io/data_parser/sql/SQLiteParser.h" #include "../io/data_provider/sql/SQLiteDataProvider.h" @@ -88,12 +93,23 @@ std::pair, std::vector> get_l #else throw std::runtime_error("ERROR: This BlazingSQL version doesn't support MySQL integration"); #endif - } else if(fileType == ral::io::DataType::SQLITE) { + } else if(fileType == ral::io::DataType::POSTGRESQL) { +#ifdef POSTGRESQL_SUPPORT + parser = std::make_shared(); + auto sql = ral::io::getSqlInfo(args_map); + provider = std::make_shared(sql, total_number_of_nodes, self_node_idx); + isSqlProvider = true; +#else + throw std::runtime_error("ERROR: This BlazingSQL version doesn't support PostgreSQL integration"); +#endif + } else if(fileType == ral::io::DataType::SQLITE) { #ifdef SQLITE_SUPPORT parser = std::make_shared(); auto sql = ral::io::getSqlInfo(args_map); provider = std::make_shared(sql, total_number_of_nodes, self_node_idx); isSqlProvider = true; +#else + throw std::runtime_error("ERROR: This BlazingSQL version doesn't support SQLite integration"); #endif } @@ -184,7 +200,7 @@ std::shared_ptr runGenerateGraph(uint32_t masterIndex, { using blazingdb::manager::Context; using blazingdb::transport::Node; - + auto& communicationData = ral::communication::CommunicationData::getInstance(); std::vector contextNodes; diff --git a/engine/src/cython/io.cpp b/engine/src/cython/io.cpp index 43c42c2c3..49cb0f3e7 100644 --- a/engine/src/cython/io.cpp +++ b/engine/src/cython/io.cpp @@ -17,6 +17,11 @@ #include "../io/data_provider/sql/MySQLDataProvider.h" #endif +#ifdef POSTGRESQL_SUPPORT +#include "../io/data_parser/sql/PostgreSQLParser.h" +#include "../io/data_provider/sql/PostgreSQLDataProvider.h" +#endif + #ifdef SQLITE_SUPPORT #include "../io/data_parser/sql/SQLiteParser.h" #include "../io/data_provider/sql/SQLiteDataProvider.h" @@ -71,6 +76,17 @@ TableSchema parseSchema(std::vector files, parser = std::make_shared(); auto sql = ral::io::getSqlInfo(args_map); provider = 
std::make_shared(sql, 0, 0); +#else + throw std::runtime_error("ERROR: This BlazingSQL version doesn't support MySQL integration"); +#endif + isSqlProvider = true; + } else if(fileType == ral::io::DataType::POSTGRESQL) { +#ifdef POSTGRESQL_SUPPORT + parser = std::make_shared(); + auto sql = ral::io::getSqlInfo(args_map); + provider = std::make_shared(sql, 0, 0); +#else + throw std::runtime_error("ERROR: This BlazingSQL version doesn't support PostgreSQL integration"); #endif isSqlProvider = true; } else if(fileType == ral::io::DataType::SQLITE) { @@ -79,6 +95,8 @@ TableSchema parseSchema(std::vector files, auto sql = ral::io::getSqlInfo(args_map); provider = std::make_shared(sql, 0, 0); isSqlProvider = true; +#else + throw std::runtime_error("ERROR: This BlazingSQL version doesn't support SQLite integration"); #endif } diff --git a/engine/src/execution_graph/logic_controllers/BatchProcessing.cpp b/engine/src/execution_graph/logic_controllers/BatchProcessing.cpp index 6a17af9d4..c79c4e40d 100644 --- a/engine/src/execution_graph/logic_controllers/BatchProcessing.cpp +++ b/engine/src/execution_graph/logic_controllers/BatchProcessing.cpp @@ -8,8 +8,13 @@ #include "io/data_provider/sql/MySQLDataProvider.h" #endif -// TODO percy -//#include "io/data_parser/sql/PostgreSQLParser.h" +#ifdef POSTGRESQL_SUPPORT +#include "io/data_provider/sql/PostgreSQLDataProvider.h" +#endif + +#ifdef SQLITE_SUPPORT +#include "io/data_provider/sql/SQLiteDataProvider.h" +#endif #include "parser/expression_utils.hpp" #include "taskflow/executor.h" @@ -128,6 +133,18 @@ TableScan::TableScan(std::size_t kernel_id, const std::string & queryString, std ral::io::set_sql_projections(provider.get(), get_projections_wrapper(schema.get_num_columns())); #else throw std::runtime_error("ERROR: This BlazingSQL version doesn't support MySQL integration"); +#endif + } else if (parser->type() == ral::io::DataType::POSTGRESQL) { +#ifdef POSTGRESQL_SUPPORT + ral::io::set_sql_projections(provider.get(), get_projections_wrapper(schema.get_num_columns())); +#else + throw std::runtime_error("ERROR: This BlazingSQL version doesn't support PostgreSQL integration"); +#endif + } else if (parser->type() == ral::io::DataType::SQLITE) { +#ifdef SQLITE_SUPPORT + ral::io::set_sql_projections(provider.get(), get_projections_wrapper(schema.get_num_columns())); +#else + throw std::runtime_error("ERROR: This BlazingSQL version doesn't support SQLite integration"); #endif } else { num_batches = provider->get_num_handles(); @@ -237,6 +254,7 @@ BindableTableScan::BindableTableScan(std::size_t kernel_id, const std::string & : kernel(kernel_id, queryString, context, kernel_type::BindableTableScanKernel), provider(provider), parser(parser), schema(schema) { this->query_graph = query_graph; this->filtered = is_filtered_bindable_scan(expression); + this->predicate_pushdown_done = false; if(parser->type() == ral::io::DataType::CUDF || parser->type() == ral::io::DataType::DASK_CUDF){ num_batches = std::max(provider->get_num_handles(), (size_t)1); @@ -263,8 +281,21 @@ BindableTableScan::BindableTableScan(std::size_t kernel_id, const std::string & } else if (parser->type() == ral::io::DataType::MYSQL) { #ifdef MYSQL_SUPPORT ral::io::set_sql_projections(provider.get(), get_projections_wrapper(schema.get_num_columns(), queryString)); + predicate_pushdown_done = ral::io::set_sql_predicate_pushdown(provider.get(), queryString); #else throw std::runtime_error("ERROR: This BlazingSQL version doesn't support MySQL integration"); +#endif + } else if (parser->type() == 
ral::io::DataType::POSTGRESQL) { +#ifdef POSTGRESQL_SUPPORT + ral::io::set_sql_projections(provider.get(), get_projections_wrapper(schema.get_num_columns(), queryString)); +#else + throw std::runtime_error("ERROR: This BlazingSQL version doesn't support PostgreSQL integration"); +#endif + } else if (parser->type() == ral::io::DataType::SQLITE) { +#ifdef SQLITE_SUPPORT + ral::io::set_sql_projections(provider.get(), get_projections_wrapper(schema.get_num_columns(), queryString)); +#else + throw std::runtime_error("ERROR: This BlazingSQL version doesn't support SQLite integration"); #endif } else { num_batches = provider->get_num_handles(); @@ -278,7 +309,7 @@ ral::execution::task_result BindableTableScan::do_process(std::vector< std::uniq std::unique_ptr filtered_input; try{ - if(this->filtered) { + if(this->filtered && !this->predicate_pushdown_done) { filtered_input = ral::processor::process_filter(input->toBlazingTableView(), expression, this->context.get()); filtered_input->setNames(fix_column_aliases(filtered_input->names(), expression)); output->addToCache(std::move(filtered_input)); diff --git a/engine/src/execution_graph/logic_controllers/BatchProcessing.h b/engine/src/execution_graph/logic_controllers/BatchProcessing.h index e6e4c5b19..8b7e8c2f2 100644 --- a/engine/src/execution_graph/logic_controllers/BatchProcessing.h +++ b/engine/src/execution_graph/logic_controllers/BatchProcessing.h @@ -203,6 +203,7 @@ class BindableTableScan : public kernel { size_t file_index = 0; size_t num_batches; bool filtered; + bool predicate_pushdown_done; }; /** diff --git a/engine/src/io/data_parser/ArgsUtil.cpp b/engine/src/io/data_parser/ArgsUtil.cpp index 07a6b3de5..9c7fd523e 100644 --- a/engine/src/io/data_parser/ArgsUtil.cpp +++ b/engine/src/io/data_parser/ArgsUtil.cpp @@ -276,8 +276,8 @@ sql_info getSqlInfo(std::map &args_map) { if (args_map.find("password") != args_map.end()) { sql.password = args_map.at("password"); } - if (args_map.find("schema") != args_map.end()) { - sql.schema = args_map.at("schema"); + if (args_map.find("database") != args_map.end()) { + sql.schema = args_map.at("database"); } if (args_map.find("table") != args_map.end()) { sql.table = args_map.at("table"); diff --git a/engine/src/io/data_parser/sql/AbstractSQLParser.cpp b/engine/src/io/data_parser/sql/AbstractSQLParser.cpp index 2a6c68560..042bea281 100644 --- a/engine/src/io/data_parser/sql/AbstractSQLParser.cpp +++ b/engine/src/io/data_parser/sql/AbstractSQLParser.cpp @@ -44,15 +44,35 @@ std::unique_ptr abstractsql_parser::parse_batch( { void *src = nullptr; + if (type() == DataType::MYSQL) { #if defined(MYSQL_SUPPORT) - src = handle.sql_handle.mysql_resultset.get(); -#elif defined(POSTGRESQL_SUPPORT) - src = handle.sql_handle.postgresql_result.get(); -#elif defined(SQLITE_SUPPORT) - src = handle.sql_handle.sqlite_statement.get(); + src = handle.sql_handle.mysql_resultset.get(); +#else + throw std::runtime_error( + "Unsupported MySQL parser for this BlazingSQL version"); #endif + } + + if (type() == DataType::POSTGRESQL) { +#if defined(POSTGRESQL_SUPPORT) + src = handle.sql_handle.postgresql_result.get(); +#else + throw std::runtime_error( + "Unsupported PostgreSQL parser for this BlazingSQL version"); +#endif + } - return this->parse_raw_batch(src, schema, column_indices, row_groups, handle.sql_handle.row_count); + if (type() == DataType::SQLITE) { +#if defined(SQLITE_SUPPORT) + src = handle.sql_handle.sqlite_statement.get(); +#else + throw std::runtime_error( + "Unsupported Sqlite3 parser for this BlazingSQL 
version"); +#endif + } + + return this->parse_raw_batch( + src, schema, column_indices, row_groups, handle.sql_handle.row_count); } void abstractsql_parser::parse_schema(ral::io::data_handle handle, ral::io::Schema & schema) { @@ -277,7 +297,7 @@ std::pair, std::vector>> init case cudf::type_id::NUM_TYPE_IDS: {} break; } } - + return std::make_pair(host_cols, null_masks); } diff --git a/engine/src/io/data_parser/sql/AbstractSQLParser.h b/engine/src/io/data_parser/sql/AbstractSQLParser.h index af44af888..ea2a9343f 100644 --- a/engine/src/io/data_parser/sql/AbstractSQLParser.h +++ b/engine/src/io/data_parser/sql/AbstractSQLParser.h @@ -1,6 +1,5 @@ /* - * Copyright 2021 BlazingDB, Inc. - * Copyright 2021 Percy Camilo Triveño Aucahuasi + * Copyright 2021 Percy Camilo Triveño Aucahuasi */ #ifndef _ABSTRACTSQLPARSER_H_ diff --git a/engine/src/io/data_parser/sql/PostgreSQLParser.cpp b/engine/src/io/data_parser/sql/PostgreSQLParser.cpp index 00c0ea114..751311088 100644 --- a/engine/src/io/data_parser/sql/PostgreSQLParser.cpp +++ b/engine/src/io/data_parser/sql/PostgreSQLParser.cpp @@ -1,10 +1,11 @@ /* - * Copyright 2021 BlazingDB, Inc. - * Copyright 2021 Cristhian Alberto Gonzales Castillo - * + * Copyright 2021 Percy Camilo Triveño Aucahuasi + * Copyright 2021 Cristhian Alberto Gonzales Castillo */ #include +#include +#include #include #include @@ -14,533 +15,329 @@ #include "PostgreSQLParser.h" #include "sqlcommon.h" +// TODO(cristhian): To optimize read about ECPG postgresql + namespace ral { namespace io { -static const std::array postgresql_string_type_hints = { - "character", "character varying", "bytea", "text", "anyarray", "name"}; +static const std::array postgresql_string_type_hints = + {"character", "character varying", "bytea", "text", "anyarray", "name"}; -static inline bool postgresql_is_cudf_string(const std::string &hint) { - const auto *result = - std::find_if(std::cbegin(postgresql_string_type_hints), - std::cend(postgresql_string_type_hints), - [&hint](const char *c_hint) { - return std::strcmp(c_hint, hint.c_str()) == 0; - }); +static inline bool postgresql_is_cudf_string(const std::string & hint) { + const auto * result = std::find_if( + std::cbegin(postgresql_string_type_hints), + std::cend(postgresql_string_type_hints), + [&hint](const char * c_hint) { return !hint.rfind(c_hint, 0); }); return result != std::cend(postgresql_string_type_hints); } -static inline cudf::io::table_with_metadata -read_postgresql(const std::shared_ptr &pgResult, - const std::vector &column_indices, - const std::vector &cudf_types, - const std::vector &column_bytes) { - const std::size_t resultNfields = PQnfields(pgResult.get()); - if (resultNfields != column_indices.size() || - resultNfields != column_bytes.size() || - resultNfields != cudf_types.size()) { - throw std::runtime_error( - "Not equal columns for indices and bytes in PostgreSQL filter"); - } - - std::vector host_cols; - host_cols.reserve(resultNfields); - const int resultNtuples = PQntuples(pgResult.get()); - const std::size_t bitmask_allocation = - cudf::bitmask_allocation_size_bytes(resultNtuples); - const std::size_t num_words = bitmask_allocation / sizeof(cudf::bitmask_type); - std::vector> null_masks(resultNfields); - std::transform( - column_indices.cbegin(), - column_indices.cend(), - std::back_inserter(host_cols), - [&pgResult, &cudf_types, &null_masks, num_words, resultNtuples]( - const int projection_index) { - null_masks[projection_index].resize(num_words, 0); - const int fsize = PQfsize(pgResult.get(), projection_index); - 
if (fsize < 0) { // STRING, STRUCT, LIST, and similar cases - auto *string_col = new cudf_string_col(); - string_col->offsets.reserve(resultNtuples + 1); - string_col->offsets.push_back(0); - return static_cast(string_col); - } - // primitives cases - const cudf::type_id cudf_type_id = cudf_types[projection_index]; - switch (cudf_type_id) { - case cudf::type_id::INT8: { - auto *vector = new std::vector; - vector->reserve(resultNtuples); - return static_cast(vector); - } - case cudf::type_id::INT16: { - auto *vector = new std::vector; - vector->reserve(resultNtuples); - return static_cast(vector); - } - case cudf::type_id::INT32: { - auto *vector = new std::vector; - vector->reserve(resultNtuples); - return static_cast(vector); - } - case cudf::type_id::INT64: { - auto *vector = new std::vector; - vector->reserve(resultNtuples); - return static_cast(vector); - } - case cudf::type_id::UINT8: { - auto *vector = new std::vector; - vector->reserve(resultNtuples); - return static_cast(vector); - } - case cudf::type_id::UINT16: { - auto *vector = new std::vector; - vector->reserve(resultNtuples); - return static_cast(vector); - } - case cudf::type_id::UINT32: { - auto *vector = new std::vector; - vector->reserve(resultNtuples); - return static_cast(vector); - } - case cudf::type_id::UINT64: { - auto *vector = new std::vector; - vector->reserve(resultNtuples); - return static_cast(vector); - } - case cudf::type_id::FLOAT32: - case cudf::type_id::DECIMAL32: { - auto *vector = new std::vector; - vector->reserve(resultNtuples); - return static_cast(vector); - } - case cudf::type_id::FLOAT64: - case cudf::type_id::DECIMAL64: { - auto *vector = new std::vector; - vector->reserve(resultNtuples); - return static_cast(vector); - } - case cudf::type_id::BOOL8: { - auto *vector = new std::vector; - vector->reserve(resultNtuples); - return static_cast(vector); - } - case cudf::type_id::STRING: { - auto *string_col = new cudf_string_col(); - string_col->offsets.reserve(resultNtuples + 1); - string_col->offsets.push_back(0); - return static_cast(string_col); - } - default: - throw std::runtime_error("Invalid allocation for cudf type id"); - } - }); - - for (int i = 0; i < resultNtuples; i++) { - for (const std::size_t projection_index : column_indices) { - cudf::type_id cudf_type_id = cudf_types[projection_index]; - const char *resultValue = PQgetvalue(pgResult.get(), i, projection_index); - const bool isNull = - static_cast(PQgetisnull(pgResult.get(), i, projection_index)); - switch (cudf_type_id) { - case cudf::type_id::INT8: { - const std::int8_t castedValue = - *reinterpret_cast(resultValue); - std::vector &vector = - *reinterpret_cast *>( - host_cols[projection_index]); - vector.push_back(castedValue); - break; - } - case cudf::type_id::INT16: { - const std::int16_t castedValue = - *reinterpret_cast(resultValue); - const std::int16_t hostOrderedValue = ntohs(castedValue); - std::vector &vector = - *reinterpret_cast *>( - host_cols[projection_index]); - vector.push_back(hostOrderedValue); - break; - } - case cudf::type_id::INT32: { - const std::int32_t castedValue = - *reinterpret_cast(resultValue); - const std::int32_t hostOrderedValue = ntohl(castedValue); - std::vector &vector = - *reinterpret_cast *>( - host_cols[projection_index]); - vector.push_back(hostOrderedValue); - break; - } - case cudf::type_id::INT64: { - const std::int64_t castedValue = - *reinterpret_cast(resultValue); - const std::int64_t hostOrderedValue = ntohl(castedValue); - std::vector &vector = - *reinterpret_cast *>( - 
host_cols[projection_index]); - vector.push_back(hostOrderedValue); - break; - } - case cudf::type_id::UINT8: { - const std::int8_t castedValue = - *reinterpret_cast(resultValue); - std::vector &vector = - *reinterpret_cast *>( - host_cols[projection_index]); - vector.push_back(castedValue); - break; - } - case cudf::type_id::UINT16: { - const std::int16_t castedValue = - *reinterpret_cast(resultValue); - const std::int16_t hostOrderedValue = ntohs(castedValue); - std::vector &vector = - *reinterpret_cast *>( - host_cols[projection_index]); - vector.push_back(hostOrderedValue); - break; - } - case cudf::type_id::UINT32: { - const std::int32_t castedValue = - *reinterpret_cast(resultValue); - const std::int32_t hostOrderedValue = ntohl(castedValue); - std::vector &vector = - *reinterpret_cast *>( - host_cols[projection_index]); - vector.push_back(hostOrderedValue); - break; - } - case cudf::type_id::UINT64: { - const std::int64_t castedValue = - *reinterpret_cast(resultValue); - const std::int64_t hostOrderedValue = ntohl(castedValue); - std::vector &vector = - *reinterpret_cast *>( - host_cols[projection_index]); - vector.push_back(hostOrderedValue); - break; - } - case cudf::type_id::FLOAT32: - case cudf::type_id::DECIMAL32: { - const std::int32_t castedValue = - *reinterpret_cast(resultValue); - const std::int32_t hostOrderedValue = ntohl(castedValue); - const float floatCastedValue = - *reinterpret_cast(&hostOrderedValue); - std::vector &vector = *reinterpret_cast *>( - host_cols[projection_index]); - vector.push_back(floatCastedValue); - break; - } - case cudf::type_id::FLOAT64: - case cudf::type_id::DECIMAL64: { - const std::int64_t castedValue = - *reinterpret_cast(resultValue); - const std::int64_t hostOrderedValue = ntohl(castedValue); - const double doubleCastedValue = - *reinterpret_cast(&hostOrderedValue); - std::vector &vector = *reinterpret_cast *>( - host_cols[projection_index]); - vector.push_back(doubleCastedValue); - break; - } - case cudf::type_id::BOOL8: { - const std::uint8_t castedValue = - *reinterpret_cast(resultValue); - std::vector &vector = - *reinterpret_cast *>( - host_cols[projection_index]); - vector.push_back(castedValue); - break; - } - case cudf::type_id::STRING: { - cudf_string_col *string_col = - reinterpret_cast(host_cols[projection_index]); - if (isNull) { - string_col->offsets.push_back(string_col->offsets.back()); - } else { - std::string data(resultValue); - string_col->chars.insert( - string_col->chars.end(), data.cbegin(), data.cend()); - string_col->offsets.push_back(string_col->offsets.back() + - data.length()); - } - break; - } - default: throw std::runtime_error("Invalid cudf type id"); - } - if (isNull) { - cudf::set_bit_unsafe(null_masks[projection_index].data(), i); - } - } - } - - cudf::io::table_with_metadata tableWithMetadata; - std::vector> cudf_columns; - cudf_columns.resize(static_cast(resultNfields)); - for (const std::size_t projection_index : column_indices) { - cudf::type_id cudf_type_id = cudf_types[projection_index]; - switch (cudf_type_id) { - case cudf::type_id::INT8: { - std::vector *vector = - reinterpret_cast *>( - host_cols[projection_index]); - cudf_columns[projection_index] = build_fixed_width_cudf_col( - resultNtuples, vector, null_masks[projection_index], cudf_type_id); - break; - } - case cudf::type_id::INT16: { - std::vector *vector = - reinterpret_cast *>( - host_cols[projection_index]); - cudf_columns[projection_index] = build_fixed_width_cudf_col( - resultNtuples, vector, null_masks[projection_index], 
cudf_type_id); - break; - } - case cudf::type_id::INT32: { - std::vector *vector = - reinterpret_cast *>( - host_cols[projection_index]); - cudf_columns[projection_index] = build_fixed_width_cudf_col( - resultNtuples, vector, null_masks[projection_index], cudf_type_id); - break; - } - case cudf::type_id::INT64: { - std::vector *vector = - reinterpret_cast *>( - host_cols[projection_index]); - cudf_columns[projection_index] = build_fixed_width_cudf_col( - resultNtuples, vector, null_masks[projection_index], cudf_type_id); - break; - } - case cudf::type_id::UINT8: { - std::vector *vector = - reinterpret_cast *>( - host_cols[projection_index]); - cudf_columns[projection_index] = build_fixed_width_cudf_col( - resultNtuples, vector, null_masks[projection_index], cudf_type_id); - break; - } - case cudf::type_id::UINT16: { - std::vector *vector = - reinterpret_cast *>( - host_cols[projection_index]); - cudf_columns[projection_index] = build_fixed_width_cudf_col( - resultNtuples, vector, null_masks[projection_index], cudf_type_id); - break; - } - case cudf::type_id::UINT32: { - std::vector *vector = - reinterpret_cast *>( - host_cols[projection_index]); - cudf_columns[projection_index] = build_fixed_width_cudf_col( - resultNtuples, vector, null_masks[projection_index], cudf_type_id); - break; - } - case cudf::type_id::UINT64: { - std::vector *vector = - reinterpret_cast *>( - host_cols[projection_index]); - cudf_columns[projection_index] = build_fixed_width_cudf_col( - resultNtuples, vector, null_masks[projection_index], cudf_type_id); - break; - } - case cudf::type_id::FLOAT32: - case cudf::type_id::DECIMAL32: { - std::vector *vector = - reinterpret_cast *>(host_cols[projection_index]); - cudf_columns[projection_index] = build_fixed_width_cudf_col( - resultNtuples, vector, null_masks[projection_index], cudf_type_id); - break; - } - case cudf::type_id::FLOAT64: - case cudf::type_id::DECIMAL64: { - std::vector *vector = - reinterpret_cast *>(host_cols[projection_index]); - cudf_columns[projection_index] = build_fixed_width_cudf_col( - resultNtuples, vector, null_masks[projection_index], cudf_type_id); - break; - } - case cudf::type_id::BOOL8: { - std::vector *vector = - reinterpret_cast *>( - host_cols[projection_index]); - cudf_columns[projection_index] = build_fixed_width_cudf_col( - resultNtuples, vector, null_masks[projection_index], cudf_type_id); - break; - } - case cudf::type_id::STRING: { - cudf_string_col *string_col = - reinterpret_cast(host_cols[projection_index]); - cudf_columns[projection_index] = - build_str_cudf_col(string_col, null_masks[projection_index]); - break; - } - default: throw std::runtime_error("Invalid cudf type id"); - } - } - - tableWithMetadata.tbl = - std::make_unique(std::move(cudf_columns)); - - for (const std::size_t projection_index : column_indices) { - cudf::type_id cudf_type_id = cudf_types[projection_index]; - switch (cudf_type_id) { - case cudf::type_id::INT8: { - delete reinterpret_cast *>( - host_cols[projection_index]); - break; - } - case cudf::type_id::INT16: { - delete reinterpret_cast *>( - host_cols[projection_index]); - break; - } - case cudf::type_id::INT32: { - delete reinterpret_cast *>( - host_cols[projection_index]); - break; - } - case cudf::type_id::INT64: { - delete reinterpret_cast *>( - host_cols[projection_index]); - break; - } - case cudf::type_id::UINT8: { - delete reinterpret_cast *>( - host_cols[projection_index]); - break; - } - case cudf::type_id::UINT16: { - delete reinterpret_cast *>( - host_cols[projection_index]); - break; - } - 
case cudf::type_id::UINT32: { - delete reinterpret_cast *>( - host_cols[projection_index]); - break; - } - case cudf::type_id::UINT64: { - delete reinterpret_cast *>( - host_cols[projection_index]); - break; - } - case cudf::type_id::FLOAT32: - case cudf::type_id::DECIMAL32: { - delete reinterpret_cast *>( - host_cols[projection_index]); - break; - } - case cudf::type_id::FLOAT64: - case cudf::type_id::DECIMAL64: { - delete reinterpret_cast *>( - host_cols[projection_index]); - break; - } - case cudf::type_id::BOOL8: { - delete reinterpret_cast *>( - host_cols[projection_index]); - break; - } - case cudf::type_id::STRING: { - delete reinterpret_cast(host_cols[projection_index]); - break; - } - default: throw std::runtime_error("Invalid cudf type id"); - } - } - return tableWithMetadata; -} - static inline cudf::type_id -MapPostgreSQLTypeName(const std::string &columnTypeName) { +parse_postgresql_column_type(const std::string & columnTypeName) { if (postgresql_is_cudf_string(columnTypeName)) { return cudf::type_id::STRING; } - if (columnTypeName == "smallint") { return cudf::type_id::INT16; } - if (columnTypeName == "integer") { return cudf::type_id::INT32; } - if (columnTypeName == "bigint") { return cudf::type_id::INT64; } - if (columnTypeName == "decimal") { return cudf::type_id::DECIMAL64; } - if (columnTypeName == "numeric") { return cudf::type_id::DECIMAL64; } - if (columnTypeName == "real") { return cudf::type_id::FLOAT32; } - if (columnTypeName == "double precision") { return cudf::type_id::FLOAT64; } - if (columnTypeName == "smallserial") { return cudf::type_id::INT16; } - if (columnTypeName == "serial") { return cudf::type_id::INT32; } - if (columnTypeName == "bigserial") { return cudf::type_id::INT64; } - if (columnTypeName == "boolean") { return cudf::type_id::BOOL8; } - if (columnTypeName == "date") { return cudf::type_id::TIMESTAMP_DAYS; } - if (columnTypeName == "money") { return cudf::type_id::UINT64; } - if (columnTypeName == "timestamp without time zone") { - return cudf::type_id::TIMESTAMP_MICROSECONDS; + if (!columnTypeName.rfind("smallint", 0)) { return cudf::type_id::INT16; } + if (!columnTypeName.rfind("integer", 0)) { return cudf::type_id::INT32; } + if (!columnTypeName.rfind("bigint", 0)) { return cudf::type_id::INT64; } + if (!columnTypeName.rfind("decimal", 0)) { return cudf::type_id::FLOAT64; } + if (!columnTypeName.rfind("numeric", 0)) { return cudf::type_id::FLOAT64; } + if (!columnTypeName.rfind("real", 0)) { return cudf::type_id::FLOAT32; } + if (!columnTypeName.rfind("double precision", 0)) { + return cudf::type_id::FLOAT64; + } + if (!columnTypeName.rfind("smallserial", 0)) { return cudf::type_id::INT16; } + if (!columnTypeName.rfind("serial", 0)) { return cudf::type_id::INT32; } + if (!columnTypeName.rfind("bigserial", 0)) { return cudf::type_id::INT64; } + if (!columnTypeName.rfind("boolean", 0)) { return cudf::type_id::BOOL8; } + if (!columnTypeName.rfind("date", 0)) { + return cudf::type_id::TIMESTAMP_MILLISECONDS; + } + if (!columnTypeName.rfind("money", 0)) { return cudf::type_id::UINT64; } + if (!columnTypeName.rfind("timestamp without time zone", 0)) { + return cudf::type_id::TIMESTAMP_MILLISECONDS; } - if (columnTypeName == "timestamp with time zone") { - return cudf::type_id::TIMESTAMP_MICROSECONDS; + if (!columnTypeName.rfind("timestamp with time zone", 0)) { + return cudf::type_id::TIMESTAMP_MILLISECONDS; } - if (columnTypeName == "time without time zone") { - return cudf::type_id::DURATION_MICROSECONDS; + if (!columnTypeName.rfind("time without 
time zone", 0)) { + return cudf::type_id::DURATION_MILLISECONDS; } - if (columnTypeName == "time with time zone") { - return cudf::type_id::DURATION_MICROSECONDS; + if (!columnTypeName.rfind("time with time zone", 0)) { + return cudf::type_id::DURATION_MILLISECONDS; } - if (columnTypeName == "interval") { - return cudf::type_id::DURATION_MICROSECONDS; + if (!columnTypeName.rfind("interval", 0)) { + return cudf::type_id::DURATION_MILLISECONDS; } - if (columnTypeName == "inet") { return cudf::type_id::UINT64; } - if (columnTypeName == "USER-DEFINED") { return cudf::type_id::STRUCT; } - if (columnTypeName == "ARRAY") { return cudf::type_id::LIST; } + if (!columnTypeName.rfind("inet", 0)) { return cudf::type_id::UINT64; } + if (!columnTypeName.rfind("USER-DEFINED", 0)) { + return cudf::type_id::STRUCT; + } + if (!columnTypeName.rfind("ARRAY", 0)) { return cudf::type_id::LIST; } throw std::runtime_error("PostgreSQL type hint not found: " + columnTypeName); } -postgresql_parser::postgresql_parser() = default; +postgresql_parser::postgresql_parser() + : abstractsql_parser{DataType::POSTGRESQL} {} postgresql_parser::~postgresql_parser() = default; -std::unique_ptr -postgresql_parser::parse_batch(data_handle handle, - const Schema &schema, - std::vector column_indices, - std::vector row_groups) { - auto pgResult = handle.sql_handle.postgresql_result; - if (!pgResult) { return schema.makeEmptyBlazingTable(column_indices); } - - if (!column_indices.empty()) { - std::vector columnNames; - columnNames.reserve(column_indices.size()); - std::transform(column_indices.cbegin(), - column_indices.cend(), - std::back_inserter(columnNames), - std::bind1st(std::mem_fun(&Schema::get_name), &schema)); - - auto tableWithMetadata = read_postgresql(pgResult, - column_indices, - schema.get_dtypes(), - handle.sql_handle.column_bytes); - tableWithMetadata.metadata.column_names = columnNames; - - auto table = std::move(tableWithMetadata.tbl); - return std::make_unique( - std::move(table), tableWithMetadata.metadata.column_names); +void postgresql_parser::read_sql_loop( + void * src, + const std::vector & cudf_types, + const std::vector & column_indices, + std::vector & host_cols, + std::vector> & null_masks) { + PGresult * result = static_cast(src); + const int ntuples = PQntuples(result); + for (int rowCounter = 0; rowCounter < ntuples; rowCounter++) { + parse_sql(src, + column_indices, + cudf_types, + rowCounter, + host_cols, + null_masks); } +} + +cudf::type_id +postgresql_parser::get_cudf_type_id(const std::string & sql_column_type) { + return parse_postgresql_column_type(sql_column_type); +} + +std::uint8_t postgresql_parser::parse_cudf_int8(void * src, + std::size_t col, + std::size_t row, + std::vector * v) { + PGresult * pgResult = static_cast(src); + if (PQgetisnull(pgResult, row, col)) { return 0; } + const char * result = PQgetvalue(pgResult, row, col); + char * end; + const std::int8_t value = + static_cast(std::strtol(result, &end, 10)); + v->at(row) = value; + return 1; +} + +std::uint8_t +postgresql_parser::parse_cudf_int16(void * src, + std::size_t col, + std::size_t row, + std::vector * v) { + PGresult * pgResult = static_cast(src); + if (PQgetisnull(pgResult, row, col)) { return 0; } + const char * result = PQgetvalue(pgResult, row, col); + char * end; + const std::int16_t value = + static_cast(std::strtol(result, &end, 10)); + v->at(row) = value; + return 1; +} + +std::uint8_t +postgresql_parser::parse_cudf_int32(void * src, + std::size_t col, + std::size_t row, + std::vector * v) { + PGresult * 
pgResult = static_cast(src); + if (PQgetisnull(pgResult, row, col)) { return 0; } + const char * result = PQgetvalue(pgResult, row, col); + char * end; + const std::int32_t value = + static_cast(std::strtol(result, &end, 10)); + v->at(row) = value; + return 1; +} + +std::uint8_t +postgresql_parser::parse_cudf_int64(void * src, + std::size_t col, + std::size_t row, + std::vector * v) { + PGresult * pgResult = static_cast(src); + if (PQgetisnull(pgResult, row, col)) { return 0; } + const char * result = PQgetvalue(pgResult, row, col); + char * end; + const std::int64_t value = + static_cast(std::strtoll(result, &end, 10)); + v->at(row) = value; + return 1; +} + +std::uint8_t +postgresql_parser::parse_cudf_uint8(void * src, + std::size_t col, + std::size_t row, + std::vector * v) { + PGresult * pgResult = static_cast(src); + if (PQgetisnull(pgResult, row, col)) { return 0; } + const char * result = PQgetvalue(pgResult, row, col); + char * end; + const std::uint8_t value = + static_cast(std::strtoul(result, &end, 10)); + v->at(row) = value; + return 1; +} + +std::uint8_t +postgresql_parser::parse_cudf_uint16(void * src, + std::size_t col, + std::size_t row, + std::vector * v) { + PGresult * pgResult = static_cast(src); + if (PQgetisnull(pgResult, row, col)) { return 0; } + const char * result = PQgetvalue(pgResult, row, col); + char * end; + const std::uint16_t value = + static_cast(std::strtol(result, &end, 10)); + v->at(row) = value; + return 1; +} + +std::uint8_t +postgresql_parser::parse_cudf_uint32(void * src, + std::size_t col, + std::size_t row, + std::vector * v) { + PGresult * pgResult = static_cast(src); + if (PQgetisnull(pgResult, row, col)) { return 0; } + const char * result = PQgetvalue(pgResult, row, col); + char * end; + const std::uint32_t value = + static_cast(std::strtoul(result, &end, 10)); + v->at(row) = value; + return 1; +} + +std::uint8_t +postgresql_parser::parse_cudf_uint64(void * src, + std::size_t col, + std::size_t row, + std::vector * v) { + PGresult * pgResult = static_cast(src); + if (PQgetisnull(pgResult, row, col)) { return 0; } + const char * result = PQgetvalue(pgResult, row, col); + char * end; + const std::uint64_t value = + static_cast(std::strtoull(result, &end, 10)); + v->at(row) = value; + return 1; +} - return nullptr; +std::uint8_t postgresql_parser::parse_cudf_float32(void * src, + std::size_t col, + std::size_t row, + std::vector * v) { + PGresult * pgResult = static_cast(src); + if (PQgetisnull(pgResult, row, col)) { return 0; } + const char * result = PQgetvalue(pgResult, row, col); + char * end; + const float value = static_cast(std::strtof(result, &end)); + v->at(row) = value; + return 1; } -void postgresql_parser::parse_schema(data_handle handle, Schema &schema) { - const bool is_in_file = true; - const std::size_t columnsLength = handle.sql_handle.column_names.size(); - for (std::size_t i = 0; i < columnsLength; i++) { - const std::string &column_type = handle.sql_handle.column_types.at(i); - cudf::type_id type = MapPostgreSQLTypeName(column_type); - const std::string &name = handle.sql_handle.column_names.at(i); - schema.add_column(name, type, i, is_in_file); +std::uint8_t postgresql_parser::parse_cudf_float64(void * src, + std::size_t col, + std::size_t row, + std::vector * v) { + PGresult * pgResult = static_cast(src); + if (PQgetisnull(pgResult, row, col)) { return 0; } + const char * result = PQgetvalue(pgResult, row, col); + char * end; + const double value = static_cast(std::strtod(result, &end)); + v->at(row) = value; + return 1; +} + 
+std::uint8_t postgresql_parser::parse_cudf_bool8(void * src, + std::size_t col, + std::size_t row, + std::vector * v) { + PGresult * pgResult = static_cast(src); + if (PQgetisnull(pgResult, row, col)) { return 0; } + const std::string result = std::string{PQgetvalue(pgResult, row, col)}; + std::int8_t value; + if (result == "t") { + value = 1; + } else { + if (result == "f") { + value = 0; + } else { + return 0; + } } + v->at(row) = value; + return 1; +} + +std::uint8_t postgresql_parser::parse_cudf_timestamp_days(void * src, + std::size_t col, + std::size_t row, + cudf_string_col * v) { + return parse_cudf_string(src, col, row, v); } -std::unique_ptr -postgresql_parser::get_metadata(std::vector handles, int offset) { - return nullptr; +std::uint8_t +postgresql_parser::parse_cudf_timestamp_seconds(void * src, + std::size_t col, + std::size_t row, + cudf_string_col * v) { + return parse_cudf_string(src, col, row, v); } +std::uint8_t +postgresql_parser::parse_cudf_timestamp_milliseconds(void * src, + std::size_t col, + std::size_t row, + cudf_string_col * v) { + return parse_cudf_string(src, col, row, v); +} + +std::uint8_t +postgresql_parser::parse_cudf_timestamp_microseconds(void * src, + std::size_t col, + std::size_t row, + cudf_string_col * v) { + return parse_cudf_string(src, col, row, v); +} + +std::uint8_t +postgresql_parser::parse_cudf_timestamp_nanoseconds(void * src, + std::size_t col, + std::size_t row, + cudf_string_col * v) { + return parse_cudf_string(src, col, row, v); +} + +std::uint8_t postgresql_parser::parse_cudf_string(void * src, + std::size_t col, + std::size_t row, + cudf_string_col * v) { + PGresult * pgResult = static_cast(src); + + if (PQgetisnull(pgResult, row, col)) { + v->offsets.push_back(v->offsets.back()); + return 0; + } + const char * result = PQgetvalue(pgResult, row, col); + std::string data = result; + + // trim spaces because postgresql store chars with padding. + Oid oid = PQftype(pgResult, col); + if (oid == InvalidOid) { throw std::runtime_error("Bad postgresql type"); } + if (oid == static_cast(1042)) { + data.erase(std::find_if(data.rbegin(), + data.rend(), + [](unsigned char c) { return !std::isspace(c); }) + .base(), + data.end()); + } + + v->chars.insert(v->chars.end(), data.cbegin(), data.cend()); + v->offsets.push_back(v->offsets.back() + data.length()); + return 1; +} + + } // namespace io } // namespace ral diff --git a/engine/src/io/data_parser/sql/PostgreSQLParser.h b/engine/src/io/data_parser/sql/PostgreSQLParser.h index 9a52d2619..4b094e3c0 100644 --- a/engine/src/io/data_parser/sql/PostgreSQLParser.h +++ b/engine/src/io/data_parser/sql/PostgreSQLParser.h @@ -1,35 +1,64 @@ /* - * Copyright 2021 BlazingDB, Inc. 
- * Copyright 2021 Cristhian Alberto Gonzales Castillo - * + * Copyright 2021 Percy Camilo Triveño Aucahuasi + * Copyright 2021 Cristhian Alberto Gonzales Castillo */ #ifndef _POSTGRESQLSQLPARSER_H_ #define _POSTGRESQLSQLPARSER_H_ -#include "io/data_parser/DataParser.h" +#include "io/data_parser/sql/AbstractSQLParser.h" namespace ral { namespace io { -class postgresql_parser : public data_parser { +class postgresql_parser : public abstractsql_parser { public: postgresql_parser(); - virtual ~postgresql_parser(); - std::unique_ptr - parse_batch(data_handle handle, - const Schema &schema, - std::vector column_indices, - std::vector row_groups) override; - - void parse_schema(data_handle handle, Schema &schema) override; +protected: + void read_sql_loop(void * src, + const std::vector & cudf_types, + const std::vector & column_indices, + std::vector & host_cols, + std::vector> & null_masks) override; - std::unique_ptr - get_metadata(std::vector handles, int offset) override; + cudf::type_id get_cudf_type_id(const std::string & sql_column_type) override; - DataType type() const override { return DataType::PARQUET; } + std::uint8_t parse_cudf_int8( + void *, std::size_t, std::size_t, std::vector *) override; + std::uint8_t parse_cudf_int16( + void *, std::size_t, std::size_t, std::vector *) override; + std::uint8_t parse_cudf_int32( + void *, std::size_t, std::size_t, std::vector *) override; + std::uint8_t parse_cudf_int64( + void *, std::size_t, std::size_t, std::vector *) override; + std::uint8_t parse_cudf_uint8( + void *, std::size_t, std::size_t, std::vector *) override; + std::uint8_t parse_cudf_uint16( + void *, std::size_t, std::size_t, std::vector *) override; + std::uint8_t parse_cudf_uint32( + void *, std::size_t, std::size_t, std::vector *) override; + std::uint8_t parse_cudf_uint64( + void *, std::size_t, std::size_t, std::vector *) override; + std::uint8_t parse_cudf_float32( + void *, std::size_t, std::size_t, std::vector *) override; + std::uint8_t parse_cudf_float64( + void *, std::size_t, std::size_t, std::vector *) override; + std::uint8_t parse_cudf_bool8( + void *, std::size_t, std::size_t, std::vector *) override; + std::uint8_t parse_cudf_timestamp_days( + void *, std::size_t, std::size_t, cudf_string_col *) override; + std::uint8_t parse_cudf_timestamp_seconds( + void *, std::size_t, std::size_t, cudf_string_col *) override; + std::uint8_t parse_cudf_timestamp_milliseconds( + void *, std::size_t, std::size_t, cudf_string_col *) override; + std::uint8_t parse_cudf_timestamp_microseconds( + void *, std::size_t, std::size_t, cudf_string_col *) override; + std::uint8_t parse_cudf_timestamp_nanoseconds( + void *, std::size_t, std::size_t, cudf_string_col *) override; + std::uint8_t parse_cudf_string( + void *, std::size_t, std::size_t, cudf_string_col *) override; }; } /* namespace io */ diff --git a/engine/src/io/data_parser/sql/SQLiteParser.cpp b/engine/src/io/data_parser/sql/SQLiteParser.cpp index 3987bdd7d..505a2a552 100644 --- a/engine/src/io/data_parser/sql/SQLiteParser.cpp +++ b/engine/src/io/data_parser/sql/SQLiteParser.cpp @@ -1,15 +1,16 @@ /* - * Copyright 2021 BlazingDB, Inc. 
- * Copyright 2021 Percy Camilo Triveño Aucahuasi + * Copyright 2021 Percy Camilo Triveño Aucahuasi + * Copyright 2021 Cristhian Alberto Gonzales Castillo */ +#include +#include + #include "SQLiteParser.h" #include "sqlcommon.h" #include "utilities/CommonOperations.h" -#include - #include "ExceptionHandling/BlazingThread.h" #include #include @@ -34,706 +35,244 @@ namespace cudf_io = cudf::io; // - NVARCHAR(100) // - TEXT // - CLOB -bool sqlite_is_cudf_string(const std::string &t) { - std::vector mysql_string_types_hints = { - "CHARACTER", - "VARCHAR", - "VARYING CHARACTER", - "NCHAR", - "NATIVE CHARACTER", - "NVARCHAR", - "TEXT", - "CLOB", - "STRING" // TODO percy ??? - }; - +static const std::array mysql_string_types_hints = { + "character", + "varchar", + "char", + "varying character", + "nchar", + "native character", + "nvarchar", + "text", + "clob", + "string" // TODO percy ??? +}; + +bool sqlite_is_cudf_string(const std::string & t) { for (auto hint : mysql_string_types_hints) { - if (StringUtil::beginsWith(t, hint)) return true; + if (StringUtil::beginsWith(t, std::string{hint})) return true; } - return false; } cudf::type_id parse_sqlite_column_type(std::string t) { - if (sqlite_is_cudf_string(t)) return cudf::type_id::STRING; std::transform( t.cbegin(), t.cend(), t.begin(), [](const std::string::value_type c) { return std::tolower(c); }); - if (t == "tinyint") { return cudf::type_id::INT8; } - if (t == "smallint") { return cudf::type_id::INT8; } - if (t == "mediumint") { return cudf::type_id::INT16; } - if (t == "int") { return cudf::type_id::INT32; } - if (t == "integer") { return cudf::type_id::INT32; } - if (t == "bigint") { return cudf::type_id::INT64; } - if (t == "unsigned big int") { return cudf::type_id::UINT64; } - if (t == "int2") { return cudf::type_id::INT16; } - if (t == "int8") { return cudf::type_id::INT64; } - if (t == "real") { return cudf::type_id::FLOAT32; } - if (t == "double") { return cudf::type_id::FLOAT64; } - if (t == "double precision") { return cudf::type_id::FLOAT64; } - if (t == "float") { return cudf::type_id::FLOAT32; } - if (t == "decimal") { return cudf::type_id::FLOAT64; } - if (t == "boolean") { return cudf::type_id::UINT8; } - if (t == "date") { return cudf::type_id::TIMESTAMP_MICROSECONDS; } - if (t == "datetime") { return cudf::type_id::TIMESTAMP_MICROSECONDS; } - -} - -std::vector -parse_sqlite_column_types(const std::vector types) { - std::vector ret; - for (auto t : types) { ret.push_back(parse_sqlite_column_type(t)); } - return ret; + if (sqlite_is_cudf_string(t)) return cudf::type_id::STRING; + if (!t.rfind("tinyint", 0)) { return cudf::type_id::INT8; } + if (!t.rfind("smallint", 0)) { return cudf::type_id::INT8; } + if (!t.rfind("mediumint", 0)) { return cudf::type_id::INT16; } + if (!t.rfind("int", 0)) { return cudf::type_id::INT32; } + if (!t.rfind("integer", 0)) { return cudf::type_id::INT32; } + if (!t.rfind("bigint", 0)) { return cudf::type_id::INT64; } + if (!t.rfind("unsigned big int", 0)) { return cudf::type_id::UINT64; } + if (!t.rfind("int2", 0)) { return cudf::type_id::INT16; } + if (!t.rfind("int8", 0)) { return cudf::type_id::INT64; } + if (!t.rfind("real", 0)) { return cudf::type_id::FLOAT32; } + if (!t.rfind("double", 0)) { return cudf::type_id::FLOAT64; } + if (!t.rfind("double precision", 0)) { return cudf::type_id::FLOAT64; } + if (!t.rfind("float", 0)) { return cudf::type_id::FLOAT32; } + if (!t.rfind("decimal", 0)) { return cudf::type_id::FLOAT64; } + if (!t.rfind("boolean", 0)) { return cudf::type_id::UINT8; } + if 
(!t.rfind("date", 0)) { return cudf::type_id::TIMESTAMP_MILLISECONDS; } + if (!t.rfind("datetime", 0)) { return cudf::type_id::TIMESTAMP_MILLISECONDS; } } -cudf::io::table_with_metadata -read_sqlite_v2(sqlite3_stmt *stmt, - const std::vector &column_indices, - const std::vector &cudf_types) { - const std::string sqlfPart{sqlite3_expanded_sql(stmt)}; - std::string::size_type fPos = sqlfPart.find("from"); - if (fPos == std::string::npos) { fPos = sqlfPart.find("FROM"); } - - std::ostringstream oss; - oss << "select count(*) " << sqlfPart.substr(fPos); - const std::string sqlnRows = oss.str(); - - std::size_t nRows = 0; - int err = sqlite3_exec( - sqlite3_db_handle(stmt), - sqlnRows.c_str(), - [](void *data, int count, char **rows, char **) -> int { - if (count == 1 && rows) { - *static_cast(data) = - static_cast(atoi(rows[0])); - return 0; - } - return 1; - }, - &nRows, - nullptr); - if (err != SQLITE_OK) { throw std::runtime_error("getting number of rows"); } - - std::size_t column_count = - static_cast(sqlite3_column_count(stmt)); - - std::vector host_cols; - host_cols.reserve(column_count); - const std::size_t bitmask_allocation = - cudf::bitmask_allocation_size_bytes(nRows); - const std::size_t num_words = bitmask_allocation / sizeof(cudf::bitmask_type); - std::vector> null_masks(column_count); - - std::transform( - column_indices.cbegin(), - column_indices.cend(), - std::back_inserter(host_cols), - [&cudf_types, &null_masks, num_words, nRows](const int projection_index) { - null_masks[projection_index].resize(num_words, 0); - const cudf::type_id cudf_type_id = cudf_types[projection_index]; - switch (cudf_type_id) { - case cudf::type_id::INT8: { - auto *vector = new std::vector; - vector->reserve(nRows); - return static_cast(vector); - } - case cudf::type_id::INT16: { - auto *vector = new std::vector; - vector->reserve(nRows); - return static_cast(vector); - } - case cudf::type_id::INT32: { - auto *vector = new std::vector; - vector->reserve(nRows); - return static_cast(vector); - } - case cudf::type_id::INT64: { - auto *vector = new std::vector; - vector->reserve(nRows); - return static_cast(vector); - } - case cudf::type_id::FLOAT32: - case cudf::type_id::DECIMAL32: { - auto *vector = new std::vector; - vector->reserve(nRows); - return static_cast(vector); - } - case cudf::type_id::FLOAT64: - case cudf::type_id::DECIMAL64: { - auto *vector = new std::vector; - vector->reserve(nRows); - return static_cast(vector); - } - case cudf::type_id::BOOL8: { - auto *vector = new std::vector; - vector->reserve(nRows); - return static_cast(vector); - } - default: - throw std::runtime_error("Invalid allocation for cudf type id"); - } - }); - - std::size_t i = 0; - while ((err = sqlite3_step(stmt)) == SQLITE_ROW) { - for (const std::size_t projection_index : column_indices) { - cudf::type_id cudf_type_id = cudf_types[projection_index]; - - const bool isNull = - sqlite3_column_type(stmt, projection_index) == SQLITE_NULL; - - switch (cudf_type_id) { - case cudf::type_id::INT8: { - break; - } - case cudf::type_id::INT16: { - break; - } - case cudf::type_id::INT32: { - const std::int32_t value = sqlite3_column_int(stmt, projection_index); - std::vector &vector = - *reinterpret_cast *>( - host_cols[projection_index]); - vector.push_back(value); - break; - } - case cudf::type_id::INT64: { - const std::int64_t value = sqlite3_column_int64(stmt, projection_index); - std::vector &vector = - *reinterpret_cast *>( - host_cols[projection_index]); - vector.push_back(value); - break; - } - case 
cudf::type_id::FLOAT32: - case cudf::type_id::DECIMAL32: { - break; - } - case cudf::type_id::FLOAT64: - case cudf::type_id::DECIMAL64: { - const double value = sqlite3_column_double(stmt, projection_index); - std::vector &vector = *reinterpret_cast *>( - host_cols[projection_index]); - vector.push_back(value); - break; - } - case cudf::type_id::BOOL8: { - break; - } - case cudf::type_id::STRING: { - cudf_string_col *string_col = - reinterpret_cast(host_cols[projection_index]); - if (isNull) { - string_col->offsets.push_back(string_col->offsets.back()); - } else { - const unsigned char *text = - sqlite3_column_text(stmt, projection_index); - const std::string value{reinterpret_cast(text)}; - - string_col->chars.insert( - string_col->chars.end(), value.cbegin(), value.cend()); - string_col->offsets.push_back(string_col->offsets.back() + - value.length()); - } - break; - } - default: throw std::runtime_error("Invalid allocation for cudf type id"); - } - if (isNull) { - cudf::set_bit_unsafe(null_masks[projection_index].data(), i); - } - i++; - } - } - - cudf::io::table_with_metadata tableWithMetadata; - std::vector> cudf_columns; - cudf_columns.resize(column_count); - for (const std::size_t projection_index : column_indices) { - cudf::type_id cudf_type_id = cudf_types[projection_index]; - switch (cudf_type_id) { - case cudf::type_id::INT8: { - std::vector *vector = - reinterpret_cast *>( - host_cols[projection_index]); - cudf_columns[projection_index] = build_fixed_width_cudf_col( - nRows, vector, null_masks[projection_index], cudf_type_id); - break; - } - case cudf::type_id::INT16: { - std::vector *vector = - reinterpret_cast *>( - host_cols[projection_index]); - cudf_columns[projection_index] = build_fixed_width_cudf_col( - nRows, vector, null_masks[projection_index], cudf_type_id); - break; - } - case cudf::type_id::INT32: { - std::vector *vector = - reinterpret_cast *>( - host_cols[projection_index]); - cudf_columns[projection_index] = build_fixed_width_cudf_col( - nRows, vector, null_masks[projection_index], cudf_type_id); - break; - } - case cudf::type_id::INT64: { - std::vector *vector = - reinterpret_cast *>( - host_cols[projection_index]); - cudf_columns[projection_index] = build_fixed_width_cudf_col( - nRows, vector, null_masks[projection_index], cudf_type_id); - break; - } - case cudf::type_id::FLOAT32: - case cudf::type_id::DECIMAL32: { - std::vector *vector = - reinterpret_cast *>(host_cols[projection_index]); - cudf_columns[projection_index] = build_fixed_width_cudf_col( - nRows, vector, null_masks[projection_index], cudf_type_id); - break; - } - case cudf::type_id::FLOAT64: - case cudf::type_id::DECIMAL64: { - std::vector *vector = - reinterpret_cast *>(host_cols[projection_index]); - cudf_columns[projection_index] = build_fixed_width_cudf_col( - nRows, vector, null_masks[projection_index], cudf_type_id); - break; - } - case cudf::type_id::BOOL8: { - std::vector *vector = - reinterpret_cast *>( - host_cols[projection_index]); - cudf_columns[projection_index] = build_fixed_width_cudf_col( - nRows, vector, null_masks[projection_index], cudf_type_id); - break; - } - case cudf::type_id::STRING: { - cudf_string_col *string_col = - reinterpret_cast(host_cols[projection_index]); - cudf_columns[projection_index] = - build_str_cudf_col(string_col, null_masks[projection_index]); - break; - } - default: throw std::runtime_error("Invalid allocation for cudf type id"); - } - } - - tableWithMetadata.tbl = - std::make_unique(std::move(cudf_columns)); - - for (const std::size_t 
projection_index : column_indices) { - cudf::type_id cudf_type_id = cudf_types[projection_index]; - switch (cudf_type_id) { - case cudf::type_id::INT8: { - delete reinterpret_cast *>( - host_cols[projection_index]); - break; - } - case cudf::type_id::INT16: { - delete reinterpret_cast *>( - host_cols[projection_index]); - break; - } - case cudf::type_id::INT32: { - delete reinterpret_cast *>( - host_cols[projection_index]); - break; - } - case cudf::type_id::INT64: { - delete reinterpret_cast *>( - host_cols[projection_index]); - break; - } - case cudf::type_id::UINT8: { - delete reinterpret_cast *>( - host_cols[projection_index]); - break; - } - case cudf::type_id::UINT16: { - delete reinterpret_cast *>( - host_cols[projection_index]); - break; - } - case cudf::type_id::UINT32: { - delete reinterpret_cast *>( - host_cols[projection_index]); - break; - } - case cudf::type_id::UINT64: { - delete reinterpret_cast *>( - host_cols[projection_index]); - break; - } - case cudf::type_id::FLOAT32: - case cudf::type_id::DECIMAL32: { - delete reinterpret_cast *>( - host_cols[projection_index]); - break; - } - case cudf::type_id::FLOAT64: - case cudf::type_id::DECIMAL64: { - delete reinterpret_cast *>( - host_cols[projection_index]); - break; - } - case cudf::type_id::BOOL8: { - delete reinterpret_cast *>( - host_cols[projection_index]); - break; - } - case cudf::type_id::STRING: { - delete reinterpret_cast(host_cols[projection_index]); - break; - } - default: throw std::runtime_error("Invalid cudf type id"); - } +sqlite_parser::sqlite_parser() : abstractsql_parser{DataType::SQLITE} {} + +sqlite_parser::~sqlite_parser() = default; + +void sqlite_parser::read_sql_loop(void * src, + const std::vector & cudf_types, + const std::vector & column_indices, + std::vector & host_cols, + std::vector> & null_masks) { + int rowCounter = 0; + sqlite3_stmt * stmt = reinterpret_cast(src); + while (sqlite3_step(stmt) == SQLITE_ROW) { + parse_sql( + src, column_indices, cudf_types, rowCounter, host_cols, null_masks); + ++rowCounter; } - return tableWithMetadata; } -cudf::io::table_with_metadata -read_sqlite(sqlite3_stmt *stmt, const std::vector types) { - int total_rows = 17; // TODO percy add this logic to the provider - cudf::io::table_with_metadata ret; - std::vector cudf_types = parse_sqlite_column_types(types); - std::vector> host_cols(types.size()); - std::vector> host_valids(host_cols.size()); - - for (int col = 0; col < host_cols.size(); ++col) { - host_cols[col].resize(total_rows); - } +cudf::type_id sqlite_parser::get_cudf_type_id( + const std::string & sql_column_type) { + return parse_sqlite_column_type(sql_column_type); +} - for (int col = 0; col < host_valids.size(); ++col) { - host_valids[col].resize(total_rows); - } +// To know about postgresql data type details +// see https://www.postgresql.org/docs/current/datatype.html + +std::uint8_t sqlite_parser::parse_cudf_int8(void * src, + std::size_t col, + std::size_t row, + std::vector * v) { + sqlite3_stmt * stmt = reinterpret_cast(src); + if (sqlite3_column_type(stmt, col) == SQLITE_NULL) { return 0; } + const std::int8_t value = + static_cast(sqlite3_column_int(stmt, col)); + v->at(row) = value; + return 1; +} - int row = 0; - int rc = 0; - while ((rc = sqlite3_step(stmt)) == SQLITE_ROW) { - for (int col = 0; col < cudf_types.size(); ++col) { - int mysql_col = col; - char *value = nullptr; - size_t data_size = 0; - cudf::type_id cudf_type_id = cudf_types[col]; - switch (cudf_type_id) { - case cudf::type_id::EMPTY: { - } break; - case 
cudf::type_id::INT8: { - - } break; - case cudf::type_id::INT16: { - - } break; - case cudf::type_id::INT32: { - int32_t type = sqlite3_column_int(stmt, mysql_col); - value = (char *) type; - data_size = sizeof(int32_t); - } break; - case cudf::type_id::INT64: { - std::int32_t type = sqlite3_column_int64(stmt, mysql_col); - value = (char *) type; - data_size = sizeof(std::int64_t); - } break; - case cudf::type_id::UINT8: { - - } break; - case cudf::type_id::UINT16: { - - } break; - case cudf::type_id::UINT32: { - - } break; - case cudf::type_id::UINT64: { - int64_t type = sqlite3_column_int64(stmt, mysql_col); - value = (char *) type; - data_size = sizeof(int64_t); - } break; - case cudf::type_id::FLOAT32: { - - } break; - case cudf::type_id::FLOAT64: { - - } break; - case cudf::type_id::BOOL8: { - - } break; - case cudf::type_id::TIMESTAMP_DAYS: { - - } break; - case cudf::type_id::TIMESTAMP_SECONDS: { - - } break; - case cudf::type_id::TIMESTAMP_MILLISECONDS: { - - } break; - case cudf::type_id::TIMESTAMP_MICROSECONDS: { - - } break; - case cudf::type_id::TIMESTAMP_NANOSECONDS: { - - } break; - case cudf::type_id::DURATION_DAYS: { - - } break; - case cudf::type_id::DURATION_SECONDS: { - - } break; - case cudf::type_id::DURATION_MILLISECONDS: { - - } break; - case cudf::type_id::DURATION_MICROSECONDS: { - - } break; - case cudf::type_id::DURATION_NANOSECONDS: { - - } break; - case cudf::type_id::DICTIONARY32: { - - } break; - case cudf::type_id::STRING: { - const unsigned char *name = sqlite3_column_text(stmt, mysql_col); - std::string tmpstr((char *) name); - data_size = tmpstr.size() + 1; // +1 for null terminating char - value = (char *) malloc(data_size); - // value = (char*)tmpstr.c_str(); - strncpy(value, tmpstr.c_str(), data_size); - } break; - case cudf::type_id::LIST: { - - } break; - case cudf::type_id::DECIMAL32: { - - } break; - case cudf::type_id::DECIMAL64: { - - } break; - case cudf::type_id::STRUCT: { - - } break; - } - host_cols[col][row] = value; - host_valids[col][row] = (value == nullptr || value == NULL) ? 
0 : 1; - // std::cout << "\t\t" << res->getString("dept_no") << "\n"; - } - ++row; - } +std::uint8_t sqlite_parser::parse_cudf_int16(void * src, + std::size_t col, + std::size_t row, + std::vector * v) { + sqlite3_stmt * stmt = reinterpret_cast(src); + if (sqlite3_column_type(stmt, col) == SQLITE_NULL) { return 0; } + const std::int16_t value = + static_cast(sqlite3_column_int(stmt, col)); + v->at(row) = value; + return 1; +} - std::vector> cudf_cols(cudf_types.size()); +std::uint8_t sqlite_parser::parse_cudf_int32(void * src, + std::size_t col, + std::size_t row, + std::vector * v) { + sqlite3_stmt * stmt = reinterpret_cast(src); + if (sqlite3_column_type(stmt, col) == SQLITE_NULL) { return 0; } + const std::int32_t value = + static_cast(sqlite3_column_int(stmt, col)); + v->at(row) = value; + return 1; +} - for (int col = 0; col < cudf_cols.size(); ++col) { - switch (cudf_types[col]) { - case cudf::type_id::EMPTY: { - } break; - case cudf::type_id::INT8: { - - } break; - case cudf::type_id::INT16: { - - } break; - case cudf::type_id::INT32: { - int32_t *cols_buff = (int32_t *) host_cols[col].data(); - uint32_t *valids_buff = (uint32_t *) host_valids[col].data(); - std::vector cols(cols_buff, cols_buff + total_rows); - std::vector valids(valids_buff, valids_buff + total_rows); - //cudf::test::fixed_width_column_wrapper vals( - //cols.begin(), cols.end(), valids.begin()); - //cudf_cols[col] = std::move(vals.release()); - } break; - case cudf::type_id::INT64: { - - } break; - case cudf::type_id::UINT8: { - - } break; - case cudf::type_id::UINT16: { - - } break; - case cudf::type_id::UINT32: { - - } break; - case cudf::type_id::UINT64: { - - } break; - case cudf::type_id::FLOAT32: { - - } break; - case cudf::type_id::FLOAT64: { - - } break; - case cudf::type_id::BOOL8: { - - } break; - case cudf::type_id::TIMESTAMP_DAYS: { - // TODO percy - // int32_t *cols_buff = (int32_t*)host_cols[col].data(); - // uint32_t *valids_buff = (uint32_t*)host_valids[col].data(); - // std::vector cols(cols_buff, cols_buff + total_rows); - // std::vector valids(valids_buff, valids_buff + - // total_rows); cudf::test::fixed_width_column_wrapper - // vals(cols.begin(), cols.end(), valids.begin()); cudf::test:: - // cudf_cols[col] = std::move(vals.release()); - } break; - case cudf::type_id::TIMESTAMP_SECONDS: { - - } break; - case cudf::type_id::TIMESTAMP_MILLISECONDS: { - - } break; - case cudf::type_id::TIMESTAMP_MICROSECONDS: { - - } break; - case cudf::type_id::TIMESTAMP_NANOSECONDS: { - - } break; - case cudf::type_id::DURATION_DAYS: { - - } break; - case cudf::type_id::DURATION_SECONDS: { - - } break; - case cudf::type_id::DURATION_MILLISECONDS: { - - } break; - case cudf::type_id::DURATION_MICROSECONDS: { - - } break; - case cudf::type_id::DURATION_NANOSECONDS: { - - } break; - case cudf::type_id::DICTIONARY32: { - - } break; - case cudf::type_id::STRING: { - std::vector cols(total_rows); - for (int row_index = 0; row_index < host_cols[col].size(); ++row_index) { - void *dat = host_cols[col][row_index]; - char *strdat = (char *) dat; - std::string v(strdat); - cols[row_index] = v; - } - - // char **cols_buff = (char**)host_cols[col].data(); - // std::vector cols(cols_buff, cols_buff + total_rows); - - uint32_t *valids_buff = (uint32_t *) host_valids[col].data(); - std::vector valids(valids_buff, valids_buff + total_rows); - - //cudf::test::strings_column_wrapper vals( - //cols.begin(), cols.end(), valids.begin()); - //cudf_cols[col] = std::move(vals.release()); - } break; - case cudf::type_id::LIST: { - - } 
break; - case cudf::type_id::DECIMAL32: { - - } break; - case cudf::type_id::DECIMAL64: { - - } break; - case cudf::type_id::STRUCT: { - - } break; - } - // cudf::strings:: - // rmm::device_buffer values(static_cast(host_cols[col].data()), - // total_rows); rmm::device_buffer null_mask(static_cast(host_valids[col].data()), total_rows); cudf::column(cudf_types[col], - // total_rows, values.data(), null_mask.data()); - } +std::uint8_t sqlite_parser::parse_cudf_int64(void * src, + std::size_t col, + std::size_t row, + std::vector * v) { + sqlite3_stmt * stmt = reinterpret_cast(src); + if (sqlite3_column_type(stmt, col) == SQLITE_NULL) { return 0; } + const std::int64_t value = + static_cast(sqlite3_column_int64(stmt, col)); + v->at(row) = value; + return 1; +} - // std::unique_ptr col = - // cudf::make_empty_column(numeric_column(cudf::data_type(cudf::type_id::INT32), - // 20); - // using DecimalTypes = cudf::test::Types; +std::uint8_t sqlite_parser::parse_cudf_uint8(void * src, + std::size_t col, + std::size_t row, + std::vector * v) { + sqlite3_stmt * stmt = reinterpret_cast(src); + if (sqlite3_column_type(stmt, col) == SQLITE_NULL) { return 0; } + const std::uint8_t value = + static_cast(sqlite3_column_int(stmt, col)); + v->at(row) = value; + return 1; +} - // std::vector dat = {5, 4, 3, 5, 8, 5, 6, 5}; - // std::vector valy = {1, 1, 1, 1, 1, 1, 1, 1}; - // cudf::test::fixed_width_column_wrapper vals(dat.begin(), - // dat.end(), valy.begin()); +std::uint8_t sqlite_parser::parse_cudf_uint16(void * src, + std::size_t col, + std::size_t row, + std::vector * v) { + sqlite3_stmt * stmt = reinterpret_cast(src); + if (sqlite3_column_type(stmt, col) == SQLITE_NULL) { return 0; } + const std::uint16_t value = + static_cast(sqlite3_column_int(stmt, col)); + v->at(row) = value; + return 1; +} - ret.tbl = std::make_unique(std::move(cudf_cols)); - return ret; +std::uint8_t sqlite_parser::parse_cudf_uint32(void * src, + std::size_t col, + std::size_t row, + std::vector * v) { + sqlite3_stmt * stmt = reinterpret_cast(src); + if (sqlite3_column_type(stmt, col) == SQLITE_NULL) { return 0; } + const std::uint32_t value = + static_cast(sqlite3_column_int(stmt, col)); + v->at(row) = value; + return 1; } -sqlite_parser::sqlite_parser() {} +std::uint8_t sqlite_parser::parse_cudf_uint64(void * src, + std::size_t col, + std::size_t row, + std::vector * v) { + sqlite3_stmt * stmt = reinterpret_cast(src); + if (sqlite3_column_type(stmt, col) == SQLITE_NULL) { return 0; } + const std::uint64_t value = + static_cast(sqlite3_column_int64(stmt, col)); + v->at(row) = value; + return 1; +} -sqlite_parser::~sqlite_parser() {} +std::uint8_t sqlite_parser::parse_cudf_float32( + void * src, std::size_t col, std::size_t row, std::vector * v) { + sqlite3_stmt * stmt = reinterpret_cast(src); + if (sqlite3_column_type(stmt, col) == SQLITE_NULL) { return 0; } + const float value = static_cast(sqlite3_column_double(stmt, col)); + v->at(row) = value; + return 1; +} -std::unique_ptr -sqlite_parser::parse_batch(ral::io::data_handle handle, - const Schema &schema, - std::vector column_indices, - std::vector row_groups) { - auto stmt = handle.sql_handle.sqlite_statement; - if (stmt == nullptr) { return schema.makeEmptyBlazingTable(column_indices); } +std::uint8_t sqlite_parser::parse_cudf_float64( + void * src, std::size_t col, std::size_t row, std::vector * v) { + sqlite3_stmt * stmt = reinterpret_cast(src); + if (sqlite3_column_type(stmt, col) == SQLITE_NULL) { return 0; } + const double value = sqlite3_column_double(stmt, col); + 
v->at(row) = value; + return 1; +} - if (column_indices.size() > 0) { - std::vector col_names(column_indices.size()); +std::uint8_t sqlite_parser::parse_cudf_bool8(void * src, + std::size_t col, + std::size_t row, + std::vector * v) { + sqlite3_stmt * stmt = reinterpret_cast(src); + if (sqlite3_column_type(stmt, col) == SQLITE_NULL) { return 0; } + const std::int8_t value = + static_cast(sqlite3_column_int(stmt, col)); + v->at(row) = value; + return 1; +} - for (size_t column_i = 0; column_i < column_indices.size(); column_i++) { - col_names[column_i] = schema.get_name(column_indices[column_i]); - } +std::uint8_t sqlite_parser::parse_cudf_timestamp_days( + void * src, std::size_t col, std::size_t row, cudf_string_col * v) { + return parse_cudf_string(src, col, row, v); +} - auto result = - read_sqlite_v2(stmt.get(), column_indices, schema.get_dtypes()); - result.metadata.column_names = col_names; +std::uint8_t sqlite_parser::parse_cudf_timestamp_seconds( + void * src, std::size_t col, std::size_t row, cudf_string_col * v) { + return parse_cudf_string(src, col, row, v); +} - auto result_table = std::move(result.tbl); - if (result.metadata.column_names.size() > column_indices.size()) { - auto columns = result_table->release(); - // Assuming columns are in the same order as column_indices and any extra - // columns (i.e. index column) are put last - columns.resize(column_indices.size()); - result_table = std::make_unique(std::move(columns)); - } +std::uint8_t sqlite_parser::parse_cudf_timestamp_milliseconds( + void * src, std::size_t col, std::size_t row, cudf_string_col * v) { + return parse_cudf_string(src, col, row, v); +} - return std::make_unique( - std::move(result_table), result.metadata.column_names); - } - return nullptr; +std::uint8_t sqlite_parser::parse_cudf_timestamp_microseconds( + void * src, std::size_t col, std::size_t row, cudf_string_col * v) { + return parse_cudf_string(src, col, row, v); } -void sqlite_parser::parse_schema(data_handle handle, Schema &schema) { - const bool is_in_file = true; - for (int i = 0; i < handle.sql_handle.column_names.size(); i++) { - const std::string &column_type = handle.sql_handle.column_types.at(i); - cudf::type_id type = parse_sqlite_column_type(column_type); - const std::string &name = handle.sql_handle.column_names.at(i); - schema.add_column(name, type, i, is_in_file); - } +std::uint8_t sqlite_parser::parse_cudf_timestamp_nanoseconds( + void * src, std::size_t col, std::size_t row, cudf_string_col * v) { + return parse_cudf_string(src, col, row, v); } -std::unique_ptr -sqlite_parser::get_metadata(std::vector handles, int offset) { - // std::vector num_row_groups(files.size()); - // BlazingThread threads[files.size()]; - // std::vector> - // parquet_readers(files.size()); for(size_t file_index = 0; file_index < - // files.size(); file_index++) { threads[file_index] = BlazingThread([&, - // file_index]() { parquet_readers[file_index] = - // std::move(parquet::ParquetFileReader::Open(files[file_index])); - // std::shared_ptr file_metadata = - // parquet_readers[file_index]->metadata(); num_row_groups[file_index] = - // file_metadata->num_row_groups(); - // }); - // } - - // for(size_t file_index = 0; file_index < files.size(); file_index++) { - // threads[file_index].join(); - // } - - // size_t total_num_row_groups = - // std::accumulate(num_row_groups.begin(), num_row_groups.end(), - // size_t(0)); - - // auto minmax_metadata_table = get_minmax_metadata(parquet_readers, - // total_num_row_groups, offset); for (auto &reader : 
parquet_readers) { - // reader->Close(); - // } - // return minmax_metadata_table; +std::uint8_t sqlite_parser::parse_cudf_string( + void * src, std::size_t col, std::size_t row, cudf_string_col * v) { + sqlite3_stmt * stmt = reinterpret_cast(src); + + std::string column_decltype = sqlite3_column_decltype(stmt, col); + + if (sqlite3_column_type(stmt, col) == SQLITE_NULL) { + v->offsets.push_back(v->offsets.back()); + return 0; + } else { + const unsigned char * text = sqlite3_column_text(stmt, col); + const std::string value{reinterpret_cast(text)}; + v->chars.insert(v->chars.end(), value.cbegin(), value.cend()); + v->offsets.push_back(v->offsets.back() + value.length()); + return 1; + } } } /* namespace io */ diff --git a/engine/src/io/data_parser/sql/SQLiteParser.h b/engine/src/io/data_parser/sql/SQLiteParser.h index 005517982..1f241d2eb 100644 --- a/engine/src/io/data_parser/sql/SQLiteParser.h +++ b/engine/src/io/data_parser/sql/SQLiteParser.h @@ -1,34 +1,64 @@ /* - * Copyright 2021 BlazingDB, Inc. - * Copyright 2021 Percy Camilo Triveño Aucahuasi + * Copyright 2021 Percy Camilo Triveño Aucahuasi + * Copyright 2021 Cristhian Alberto Gonzales Castillo */ #ifndef _SQLITEPARSER_H_ #define _SQLITEPARSER_H_ -#include "io/data_parser/DataParser.h" +#include "io/data_parser/sql/AbstractSQLParser.h" namespace ral { namespace io { -class sqlite_parser : public data_parser { +class sqlite_parser : public abstractsql_parser { public: sqlite_parser(); - virtual ~sqlite_parser(); - std::unique_ptr - parse_batch(data_handle handle, - const Schema &schema, - std::vector column_indices, - std::vector row_groups) override; - - void parse_schema(data_handle handle, Schema &schema) override; +protected: + void read_sql_loop(void * src, + const std::vector & cudf_types, + const std::vector & column_indices, + std::vector & host_cols, + std::vector> & null_masks) override; - std::unique_ptr - get_metadata(std::vector handles, int offset) override; + cudf::type_id get_cudf_type_id(const std::string & sql_column_type) override; - DataType type() const override { return DataType::PARQUET; } + std::uint8_t parse_cudf_int8( + void *, std::size_t, std::size_t, std::vector *) override; + std::uint8_t parse_cudf_int16( + void *, std::size_t, std::size_t, std::vector *) override; + std::uint8_t parse_cudf_int32( + void *, std::size_t, std::size_t, std::vector *) override; + std::uint8_t parse_cudf_int64( + void *, std::size_t, std::size_t, std::vector *) override; + std::uint8_t parse_cudf_uint8( + void *, std::size_t, std::size_t, std::vector *) override; + std::uint8_t parse_cudf_uint16( + void *, std::size_t, std::size_t, std::vector *) override; + std::uint8_t parse_cudf_uint32( + void *, std::size_t, std::size_t, std::vector *) override; + std::uint8_t parse_cudf_uint64( + void *, std::size_t, std::size_t, std::vector *) override; + std::uint8_t parse_cudf_float32( + void *, std::size_t, std::size_t, std::vector *) override; + std::uint8_t parse_cudf_float64( + void *, std::size_t, std::size_t, std::vector *) override; + std::uint8_t parse_cudf_bool8( + void *, std::size_t, std::size_t, std::vector *) override; + std::uint8_t parse_cudf_timestamp_days( + void *, std::size_t, std::size_t, cudf_string_col *) override; + std::uint8_t parse_cudf_timestamp_seconds( + void *, std::size_t, std::size_t, cudf_string_col *) override; + std::uint8_t parse_cudf_timestamp_milliseconds( + void *, std::size_t, std::size_t, cudf_string_col *) override; + std::uint8_t parse_cudf_timestamp_microseconds( + void *, std::size_t, 
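// ---------------------------------------------------------------------------
// Illustrative sketch, not part of this patch: the cudf_string_col layout that
// the parse_cudf_string() override above fills in. Strings are kept as one
// flat character buffer plus an offsets vector (offsets[i]..offsets[i+1] is
// row i); a SQL NULL repeats the previous offset so it contributes zero
// characters. The struct and helper below are hypothetical stand-ins for the
// real cudf_string_col used by the abstract SQL parser, shown only to make
// that convention explicit.
#include <cstdint>
#include <string>
#include <vector>

struct sketch_string_col {
  std::vector<char> chars;
  std::vector<std::int32_t> offsets{0};  // always starts with a leading 0
};

void sketch_append(sketch_string_col & col, const std::string * value) {
  if (value == nullptr) {  // SQL NULL: no characters, repeat the last offset
    col.offsets.push_back(col.offsets.back());
    return;
  }
  col.chars.insert(col.chars.end(), value->cbegin(), value->cend());
  col.offsets.push_back(col.offsets.back() +
                        static_cast<std::int32_t>(value->size()));
}
// ---------------------------------------------------------------------------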
std::size_t, cudf_string_col *) override; + std::uint8_t parse_cudf_timestamp_nanoseconds( + void *, std::size_t, std::size_t, cudf_string_col *) override; + std::uint8_t parse_cudf_string( + void *, std::size_t, std::size_t, cudf_string_col *) override; }; } /* namespace io */ diff --git a/engine/src/io/data_parser/sql/sqlcommon.h b/engine/src/io/data_parser/sql/sqlcommon.h index 84e8cc992..b36c46456 100644 --- a/engine/src/io/data_parser/sql/sqlcommon.h +++ b/engine/src/io/data_parser/sql/sqlcommon.h @@ -1,3 +1,7 @@ +/* + * Copyright 2021 Percy Camilo Triveño Aucahuasi + */ + #include #include #include diff --git a/engine/src/io/data_provider/DataProvider.h b/engine/src/io/data_provider/DataProvider.h index 268c85d51..ccab3b3d3 100644 --- a/engine/src/io/data_provider/DataProvider.h +++ b/engine/src/io/data_provider/DataProvider.h @@ -51,7 +51,6 @@ struct sql_datasource { std::string query; std::vector column_names; std::vector column_types; // always uppercase - std::vector column_bytes; size_t row_count; #ifdef MYSQL_SUPPORT std::shared_ptr mysql_resultset = nullptr; diff --git a/engine/src/io/data_provider/sql/AbstractSQLDataProvider.cpp b/engine/src/io/data_provider/sql/AbstractSQLDataProvider.cpp index 25e977575..78574ceee 100644 --- a/engine/src/io/data_provider/sql/AbstractSQLDataProvider.cpp +++ b/engine/src/io/data_provider/sql/AbstractSQLDataProvider.cpp @@ -1,13 +1,99 @@ /* - * Copyright 2021 BlazingDB, Inc. - * Copyright 2021 Percy Camilo Triveño Aucahuasi + * Copyright 2021 Percy Camilo Triveño Aucahuasi */ #include "AbstractSQLDataProvider.h" +#include "parser/expression_utils.hpp" +#include "CalciteExpressionParsing.h" +#include "parser/expression_tree.hpp" + namespace ral { namespace io { +bool parse_blazingsql_predicate(const std::string &source, parser::parse_tree &ast) +{ + if (source.empty()) return false; + std::string filter_string = get_named_expression(source, "filters"); + if (filter_string.empty()) return false; + std::string predicate = replace_calcite_regex(filter_string); + if (predicate.empty()) return false; + ast.build(predicate); + return true; +} + +bool transform_sql_predicate(parser::parse_tree &source_ast, + ral::parser::node_transformer* predicate_transformer) +{ + // TODO percy c.gonzales try/catch and check if other conds can return false + source_ast.transform(*predicate_transformer); + return true; +} + +std::string in_order(const ral::parser::node &node) { + using ral::parser::operator_node; + if (&node == nullptr) { + return ""; + } + if (node.type == ral::parser::node_type::OPERATOR) { + ral::parser::operator_node &op_node = ((ral::parser::operator_node&)node); + operator_type op = map_to_operator_type(op_node.value); + if (is_unary_operator(op)) { + auto c1 = node.children[0].get(); + auto placement = op_node.placement; + if (placement == operator_node::END) { + auto body = in_order(*c1) + " " + op_node.label; + if (op_node.parentheses_wrap) { + return "(" + body + ")"; + } else { + return body; + } + } else if (placement == operator_node::BEGIN || placement == operator_node::AUTO) { + auto body = op_node.label + " " + in_order(*c1); + if (op_node.parentheses_wrap) { + return "(" + body + ")"; + } else { + return body; + } + } + } else if (is_binary_operator(op)) { + std::string body; + for (int i = 0; i < node.children.size(); ++i) { + auto c = node.children[i].get(); + body += in_order(*c); + if (i < node.children.size()-1) { + body += " " + op_node.label + " "; + } + } + if (op_node.parentheses_wrap) { + return "(" + body + ")"; + } else { 
+ return body; + } + } + } + return node.value; +} + +std::string generate_sql_predicate(ral::parser::parse_tree &target_ast) { + return in_order(target_ast.root()); +} + +// NOTE +// the source ast and the target ast are based on the same struct: ral::parser::parse_tree +// for now we can use that one for this use case, but in the future it would be a good idea +// to create a dedicated ast for each target sql backend/dialect +std::string transpile_sql_predicate(const std::string &source, + ral::parser::node_transformer *predicate_transformer) +{ + parser::parse_tree ast; + if (!parse_blazingsql_predicate(source, ast)) return ""; + if (transform_sql_predicate(ast, predicate_transformer)) { + return generate_sql_predicate(ast); + } + return ""; +} + abstractsql_data_provider::abstractsql_data_provider( const sql_info &sql, size_t total_number_of_nodes, @@ -15,7 +101,7 @@ abstractsql_data_provider::abstractsql_data_provider( : data_provider(), sql(sql), total_number_of_nodes(total_number_of_nodes), self_node_idx(self_node_idx) {} abstractsql_data_provider::~abstractsql_data_provider() { - this->close_file_handles(); + this->close_file_handles(); } std::vector abstractsql_data_provider::get_some(std::size_t batch_count, bool){ @@ -38,8 +124,22 @@ void abstractsql_data_provider::close_file_handles() { // NOTE we don't use any file handle for this provider so nothing to do here } -std::string abstractsql_data_provider::build_select_from() const { +bool abstractsql_data_provider::set_predicate_pushdown(const std::string &queryString) +{ + // DEBUG + //std::cout << "\nORIGINAL query part for the predicate pushdown:\n" << queryString << "\n\n"; + auto predicate_transformer = this->get_predicate_transformer(); + this->where = transpile_sql_predicate(queryString, predicate_transformer.get()); + // DEBUG + //std::cout << "\nWHERE stmt for the predicate pushdown:\n" << this->where << "\n\n"; + return !this->where.empty(); +} + +std::string abstractsql_data_provider::build_select_query( + std::size_t batch_index, + const std::string & orderBy) const { std::string cols; + if (this->column_indices.empty()) { cols = "* "; } else { @@ -53,11 +153,29 @@ std::string abstractsql_data_provider::build_select_from() const { } } } - return "SELECT " + cols + "FROM " + this->sql.table; -} -std::string abstractsql_data_provider::build_limit_offset(size_t offset) const { - return " LIMIT " + std::to_string(this->sql.table_batch_size) + " OFFSET " + std::to_string(offset); + std::ostringstream oss; + + oss << "SELECT " << cols << " FROM " << this->sql.table; + + if (sql.table_filter.empty()) { + if (!this->where + .empty()) { // then the filter is from the predicate pushdown{ + oss << " where " << this->where; + } + } else { + oss << " where " << sql.table_filter; + } + + if (!orderBy.empty()) { oss << " order by " << orderBy; } + + const size_t offset = + this->sql.table_batch_size * + (this->total_number_of_nodes * batch_index + this->self_node_idx); + + oss << " LIMIT " << this->sql.table_batch_size << " OFFSET " << offset; + + return oss.str(); } } /* namespace io */ diff --git a/engine/src/io/data_provider/sql/AbstractSQLDataProvider.h b/engine/src/io/data_provider/sql/AbstractSQLDataProvider.h index 73f8a1de7..d40a4bc4c 100644 --- a/engine/src/io/data_provider/sql/AbstractSQLDataProvider.h +++ b/engine/src/io/data_provider/sql/AbstractSQLDataProvider.h @@ -6,6 +6,8 @@ #define ABSTRACTSQLDATAPROVIDER_H_ #include "io/data_provider/DataProvider.h" +#include "io/Schema.h" +#include "parser/expression_tree.hpp" 
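// ---------------------------------------------------------------------------
// Illustrative sketch, not part of this patch: how the pushed-down predicate
// ends up inside the per-batch SELECT. Assuming `where_clause` is whatever
// transpile_sql_predicate() produced and `key_column` is the optional ORDER BY
// key, the provider builds one query per batch along these lines. The function
// name and parameters below are hypothetical; only the
// SELECT/WHERE/ORDER BY/LIMIT/OFFSET shape mirrors build_select_query() above.
#include <cstddef>
#include <sstream>
#include <string>

std::string sketch_build_select_query(const std::string & table,
                                      const std::string & columns,       // e.g. "l_partkey, l_shipdate"
                                      const std::string & where_clause,  // pushed-down filter, may be empty
                                      const std::string & key_column,    // ORDER BY key, may be empty
                                      std::size_t batch_size,
                                      std::size_t batch_index,
                                      std::size_t total_nodes,
                                      std::size_t node_idx) {
  std::ostringstream oss;
  oss << "SELECT " << (columns.empty() ? "*" : columns) << " FROM " << table;
  if (!where_clause.empty()) { oss << " where " << where_clause; }
  if (!key_column.empty()) { oss << " order by " << key_column; }
  // Nodes read disjoint, interleaved batches of `batch_size` rows each.
  const std::size_t offset = batch_size * (total_nodes * batch_index + node_idx);
  oss << " LIMIT " << batch_size << " OFFSET " << offset;
  return oss.str();
}
// ---------------------------------------------------------------------------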
namespace ral { namespace io { @@ -50,22 +52,24 @@ class abstractsql_data_provider : public data_provider { */ void set_column_indices(std::vector column_indices) { this->column_indices = column_indices; } + bool set_predicate_pushdown(const std::string &queryString); + protected: - // returns SELECT ... FROM - std::string build_select_from() const; + virtual std::unique_ptr get_predicate_transformer() const = 0; - // returns LIMIT ... OFFSET - std::string build_limit_offset(size_t offset) const; +protected: + // returns SELECT ... FROM ... WHERE ... LIMIT ... OFFSET + std::string build_select_query(std::size_t batch_index, + const std::string & orderBy = "") const; protected: sql_info sql; - std::pair query_parts; std::vector column_indices; std::vector column_names; std::vector column_types; - std::vector column_bytes; size_t total_number_of_nodes; size_t self_node_idx; + std::string where; }; template @@ -74,6 +78,13 @@ void set_sql_projections(data_provider *provider, const std::vector &projec sql_provider->set_column_indices(projections); } +template +bool set_sql_predicate_pushdown(data_provider *provider, const std::string &queryString) +{ + auto sql_provider = static_cast(provider); + return sql_provider->set_predicate_pushdown(queryString); +} + } /* namespace io */ } /* namespace ral */ diff --git a/engine/src/io/data_provider/sql/MySQLDataProvider.cpp b/engine/src/io/data_provider/sql/MySQLDataProvider.cpp index 005a4a805..aa7e9d2e5 100644 --- a/engine/src/io/data_provider/sql/MySQLDataProvider.cpp +++ b/engine/src/io/data_provider/sql/MySQLDataProvider.cpp @@ -31,7 +31,12 @@ struct mysql_table_info { struct mysql_columns_info { std::vector columns; std::vector types; - std::vector bytes; +}; + +struct mysql_operator_info { + std::string label; + ral::parser::operator_node::placement_type placement = ral::parser::operator_node::AUTO; + bool parentheses_wrap = true; }; /* MySQL supports these connection properties: @@ -79,7 +84,6 @@ sql::ConnectOptionsMap build_jdbc_mysql_connection(const std::string host, ret["userName"] = user; ret["password"] = password; ret["schema"] = schema; - // TODO percy set chunk size here return ret; } @@ -96,8 +100,16 @@ std::shared_ptr execute_mysql_query(sql::Connection *con, delete pointer; }; - std::shared_ptr stmt(con->createStatement(), stmt_deleter); - return std::shared_ptr(stmt->executeQuery(query), resultset_deleter); + std::shared_ptr ret = nullptr; + try { + std::shared_ptr stmt(con->createStatement(), stmt_deleter); + ret = std::shared_ptr(stmt->executeQuery(query), resultset_deleter); + // NOTE do not enable setFetchSize, we want to read by batches + } catch (sql::SQLException &e) { + throw std::runtime_error("ERROR: Could not run the MySQL query " + query + ": " + e.what()); + } + + return ret; } mysql_table_info get_mysql_table_info(sql::Connection *con, const std::string &table) { @@ -116,38 +128,12 @@ mysql_table_info get_mysql_table_info(sql::Connection *con, const std::string &t break; // we should not have more than 1 row here } } catch (sql::SQLException &e) { - // TODO percy + throw std::runtime_error("ERROR: Could not get table information for MySQL " + table + ": " + e.what()); } return ret; } -// TODO percy avoid repeated code -bool is_string_test(const std::string &t) { - std::vector mysql_string_types_hints = { - "CHAR", - "VARCHAR", - "BINARY", - "VARBINARY", - "TINYBLOB", - "TINYTEXT", - "TEXT", - "BLOB", - "MEDIUMTEXT", - "MEDIUMBLOB", - "LONGTEXT", - "LONGBLOB", - "ENUM", - "SET" - }; - - for (auto hint : 
mysql_string_types_hints) { - if (StringUtil::beginsWith(t, hint)) return true; - } - - return false; -} - mysql_columns_info get_mysql_columns_info(sql::Connection *con, const std::string &table) { mysql_columns_info ret; @@ -160,33 +146,93 @@ mysql_columns_info get_mysql_columns_info(sql::Connection *con, while (res->next()) { std::string col_name = res->getString("COLUMN_NAME").asStdString(); std::string col_type = StringUtil::toUpper(res->getString("COLUMN_TYPE").asStdString()); - size_t max_bytes = 8; // max bytes date = 5+3(frac secs) = 8 ... then the largest comes from strings - - if (is_string_test(col_type)) { - max_bytes = res->getUInt64("CHARACTER_MAXIMUM_LENGTH"); - } else if (col_type == "DATETIME" || col_type == "TIMESTAMP") { - // NOTE mysql jdbc represents mysql date/datetime types as strings so is better to reserve a good amount here - max_bytes = 48; - } - ret.columns.push_back(col_name); ret.types.push_back(col_type); - ret.bytes.push_back(max_bytes); } } catch (sql::SQLException &e) { - // TODO percy + throw std::runtime_error("ERROR: Could not get columns information for MySQL " + table + ": " + e.what()); } return ret; } +static std::map get_mysql_operators() { + using ral::parser::operator_node; + + // see https://dev.mysql.com/doc/refman/8.0/en/non-typed-operators.html + static std::map operators; + if (operators.empty()) { + operators[operator_type::BLZ_IS_NOT_NULL] = {.label = "IS NOT NULL", .placement = operator_node::END}; + operators[operator_type::BLZ_IS_NULL] = {.label = "IS NULL", .placement = operator_node::END}; + } + return operators; +} + +class mysql_predicate_transformer : public parser::node_transformer { +public: + mysql_predicate_transformer( + const std::vector &column_indices, + const std::vector &column_names) + : column_indices(column_indices), column_names(column_names) {} + + virtual ~mysql_predicate_transformer() {} + + parser::node * transform(parser::operad_node& node) override { + auto ndir = &((ral::parser::node&)node); + if (this->visited.count(ndir)) { + return &node; + } + if (node.type == ral::parser::node_type::VARIABLE) { + std::string var = StringUtil::split(node.value, "$")[1]; + size_t idx = std::atoi(var.c_str()); + size_t col = column_indices[idx]; + node.value = column_names[col]; + } else if (node.type == ral::parser::node_type::LITERAL) { + ral::parser::literal_node &literal_node = ((ral::parser::literal_node&)node); + if (literal_node.type().id() == cudf::type_id::TIMESTAMP_DAYS || + literal_node.type().id() == cudf::type_id::TIMESTAMP_SECONDS || + literal_node.type().id() == cudf::type_id::TIMESTAMP_NANOSECONDS || + literal_node.type().id() == cudf::type_id::TIMESTAMP_MICROSECONDS || + literal_node.type().id() == cudf::type_id::TIMESTAMP_MILLISECONDS) + { + node.value = "\"" + node.value + "\""; + } + } + this->visited[ndir] = true; + return &node; + } + + parser::node * transform(parser::operator_node& node) override { + auto ndir = &((ral::parser::node&)node); + if (this->visited.count(ndir)) { + return &node; + } + if (!get_mysql_operators().empty()) { + operator_type op = map_to_operator_type(node.value); + if (get_mysql_operators().count(op)) { + auto op_obj = get_mysql_operators().at(op); + node.label = op_obj.label; + node.placement = op_obj.placement; + node.parentheses_wrap = op_obj.parentheses_wrap; + } + } + this->visited[ndir] = true; + return &node; + } + +private: + std::map visited; + std::vector column_indices; + std::vector column_names; +}; + mysql_data_provider::mysql_data_provider( const sql_info &sql, 
size_t total_number_of_nodes, size_t self_node_idx) : abstractsql_data_provider(sql, total_number_of_nodes, self_node_idx) , mysql_connection(nullptr), estimated_table_row_count(0) - , batch_position(0), table_fetch_completed(false) + , batch_index(0), table_fetch_completed(false) { sql::Driver *driver = sql::mysql::get_driver_instance(); sql::ConnectOptionsMap options = build_jdbc_mysql_connection(this->sql.host, @@ -202,7 +248,6 @@ mysql_data_provider::mysql_data_provider( this->sql.table); this->column_names = cols_info.columns; this->column_types = cols_info.types; - this->column_bytes = cols_info.bytes; } mysql_data_provider::~mysql_data_provider() { @@ -218,7 +263,7 @@ bool mysql_data_provider::has_next() { void mysql_data_provider::reset() { this->table_fetch_completed = false; - this->batch_position = 0; + this->batch_index = 0; } data_handle mysql_data_provider::get_next(bool open_file) { @@ -231,14 +276,11 @@ data_handle mysql_data_provider::get_next(bool open_file) { return ret; } - std::string select_from = this->build_select_from(); - std::string where = this->sql.table_filter.empty()? "" : " where "; - - size_t offset = this->sql.table_batch_size * (this->total_number_of_nodes * this->batch_position + this->self_node_idx); - std::string query = select_from + where + this->sql.table_filter + this->build_limit_offset(offset); + std::string query = this->build_select_query(this->batch_index); + // DEBUG - //std::cout << ">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>MYSQL QUERY:\n\n" << query << "\n\n\n"; - ++this->batch_position; + //std::cout << "\n>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>MYSQL QUERY:\n\n" << query << "\n\n\n"; + ++this->batch_index; auto res = execute_mysql_query(this->mysql_connection.get(), query); if (res->rowsCount() == 0) { @@ -248,7 +290,6 @@ data_handle mysql_data_provider::get_next(bool open_file) { ret.sql_handle.table = this->sql.table; ret.sql_handle.column_names = this->column_names; ret.sql_handle.column_types = this->column_types; - ret.sql_handle.column_bytes = this->column_bytes; ret.sql_handle.mysql_resultset = res; ret.sql_handle.row_count = res->rowsCount(); // TODO percy add columns to uri.query @@ -261,5 +302,13 @@ size_t mysql_data_provider::get_num_handles() { return ret == 0? 1 : ret; } +std::unique_ptr mysql_data_provider::get_predicate_transformer() const +{ + return std::unique_ptr(new mysql_predicate_transformer( + this->column_indices, + this->column_names + )); +} + } /* namespace io */ } /* namespace ral */ diff --git a/engine/src/io/data_provider/sql/MySQLDataProvider.h b/engine/src/io/data_provider/sql/MySQLDataProvider.h index d2b1f587f..5a899a70f 100644 --- a/engine/src/io/data_provider/sql/MySQLDataProvider.h +++ b/engine/src/io/data_provider/sql/MySQLDataProvider.h @@ -49,10 +49,13 @@ class mysql_data_provider : public abstractsql_data_provider { */ size_t get_num_handles() override; +protected: + std::unique_ptr get_predicate_transformer() const override; + private: std::unique_ptr mysql_connection; size_t estimated_table_row_count; - size_t batch_position; + size_t batch_index; bool table_fetch_completed; }; diff --git a/engine/src/io/data_provider/sql/PostgreSQLDataProvider.cpp b/engine/src/io/data_provider/sql/PostgreSQLDataProvider.cpp index ada72973f..8ce27b06f 100644 --- a/engine/src/io/data_provider/sql/PostgreSQLDataProvider.cpp +++ b/engine/src/io/data_provider/sql/PostgreSQLDataProvider.cpp @@ -1,7 +1,6 @@ /* - * Copyright 2021 BlazingDB, Inc. 
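// ---------------------------------------------------------------------------
// Illustrative sketch, not part of this patch: what the MySQL predicate
// transformer effectively does to a VARIABLE node. A Calcite reference such as
// "$3" is an index into the scan's projection list (column_indices), which in
// turn points at the real table column name. The helper below is hypothetical
// and only mirrors that lookup.
#include <cstddef>
#include <string>
#include <vector>

std::string sketch_resolve_variable(const std::string & var,                 // e.g. "$3"
                                    const std::vector<int> & column_indices, // e.g. {1, 5, 6, 10}
                                    const std::vector<std::string> & column_names) {
  // "$3" -> projection slot 3 -> table column index 10 -> "l_shipdate"
  const std::size_t slot = static_cast<std::size_t>(std::stoi(var.substr(1)));
  const std::size_t col = static_cast<std::size_t>(column_indices.at(slot));
  return column_names.at(col);
}

// With the lineitem example used in the test file below, the Calcite filter
//   AND(>=($3, 1995-09-01), <($3, 1995-10-01))
// would be transpiled into roughly
//   (l_shipdate >= "1995-09-01" AND l_shipdate < "1995-10-01")
// with timestamp literals double-quoted by the LITERAL branch of the
// transformer.
// ---------------------------------------------------------------------------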
- * Copyright 2021 Cristhian Alberto Gonzales Castillo - * + * Copyright 2021 Percy Camilo Triveño Aucahuasi + * Copyright 2021 Cristhian Alberto Gonzales Castillo */ #include @@ -15,14 +14,14 @@ namespace io { namespace { -const std::string MakePostgreSQLConnectionString(const sql_info &sql) { +const std::string MakePostgreSQLConnectionString(const sql_info & sql) { std::ostringstream os; os << "host=" << sql.host << " port=" << sql.port << " dbname=" << sql.schema << " user=" << sql.user << " password=" << sql.password; return os.str(); } -const std::string MakeQueryForColumnsInfo(const sql_info &sql) { +const std::string MakeQueryForColumnsInfo(const sql_info & sql) { std::ostringstream os; os << "select column_name, data_type, character_maximum_length" " from information_schema.tables as tables" @@ -37,11 +36,11 @@ class TableInfo { public: std::vector column_names; std::vector column_types; - std::vector column_bytes; + std::size_t row_count; }; -TableInfo ExecuteTableInfo(PGconn *connection, const sql_info &sql) { - PGresult *result = PQexec(connection, MakeQueryForColumnsInfo(sql).c_str()); +inline TableInfo ExecuteTableInfo(PGconn * connection, const sql_info & sql) { + PGresult * result = PQexec(connection, MakeQueryForColumnsInfo(sql).c_str()); if (PQresultStatus(result) != PGRES_TUPLES_OK) { PQclear(result); PQfinish(connection); @@ -64,64 +63,161 @@ TableInfo ExecuteTableInfo(PGconn *connection, const sql_info &sql) { std::string{PQgetvalue(result, i, dataTypeFn)}); // NOTE character_maximum_length is used for char or byte string type - if (PQgetisnull(result, i, characterMaximumLengthFn)) { - // TODO(recy, cristhian): check the minimum size for types - tableInfo.column_bytes.emplace_back(8); - } else { - const char *characterMaximumLengthBytes = + if (!PQgetisnull(result, i, characterMaximumLengthFn)) { + const char * characterMaximumLengthBytes = PQgetvalue(result, i, characterMaximumLengthFn); // NOTE postgresql representation of number is in network order const std::uint32_t characterMaximumLength = ntohl(*reinterpret_cast( characterMaximumLengthBytes)); - tableInfo.column_bytes.emplace_back( - static_cast(characterMaximumLength)); } } - + PQclear(result); + const std::string query = "select count(*) from " + sql.table; + result = PQexec(connection, query.c_str()); + if (PQresultStatus(result) != PGRES_TUPLES_OK) { + PQclear(result); + PQfinish(connection); + throw std::runtime_error("Error access for columns info"); + } + const char * value = PQgetvalue(result, 0, 0); + char * end; + tableInfo.row_count = std::strtoll(value, &end, 10); + PQclear(result); return tableInfo; } } // namespace -postgresql_data_provider::postgresql_data_provider(const sql_info &sql, - size_t total_number_of_nodes, - size_t self_node_idx) - : abstractsql_data_provider(sql, total_number_of_nodes, self_node_idx), table_fetch_completed{false}, - batch_position{0}, estimated_table_row_count{0} { +static inline std::string FindKeyName(PGconn * connection, + const sql_info & sql) { + // This function exists because when we get batches from table we use LIMIT + // clause and since postgresql returns unpredictable subsets of query's rows, + // we apply a group by a column in order to keep some order for result query + // see https://www.postgresql.org/docs/13/queries-limit.html + std::ostringstream oss; + oss << "select column_name, ordinal_position" + " from information_schema.table_constraints tc" + " join information_schema.key_column_usage kcu" + " on tc.constraint_name = kcu.constraint_name" + " and 
tc.constraint_schema = kcu.constraint_schema" + " and tc.constraint_name = kcu.constraint_name" + " where tc.table_catalog = '" + << sql.schema << "' and tc.table_name = '" << sql.table + << "' and constraint_type = 'PRIMARY KEY'"; + const std::string query = oss.str(); + PGresult * result = PQexec(connection, query.c_str()); + if (PQresultStatus(result) != PGRES_TUPLES_OK) { + PQclear(result); + PQfinish(connection); + throw std::runtime_error("Error access for columns info"); + } + + if (PQntuples(result)) { + int columnNameFn = PQfnumber(result, "column_name"); + const std::string columnName{PQgetvalue(result, 0, columnNameFn)}; + PQclear(result); + if (columnName.empty()) { + throw std::runtime_error("No column name into result for primary key"); + } else { + return columnName; + } + } else { + // here table doesn't have a primary key, so we choose a column by type + // the primitive types like int or float have priority over other types + PQclear(result); + std::ostringstream oss; + oss << "select column_name, oid, case" + " when typname like 'int_' then 1" + " when typname like 'float_' then 2" + " else 99 end as typorder" + " from information_schema.tables as tables" + " join information_schema.columns as columns" + " on tables.table_name = columns.table_name" + " join pg_type on udt_name = typname where tables.table_catalog = '" + << sql.schema << "' and tables.table_name = '" << sql.table + << "' order by typorder, typlen desc, oid"; + + const std::string query = oss.str(); + PGresult * result = PQexec(connection, query.c_str()); + if (PQresultStatus(result) != PGRES_TUPLES_OK) { + PQclear(result); + PQfinish(connection); + throw std::runtime_error("Error access for columns info"); + } + + if (PQntuples(result)) { + int columnNameFn = PQfnumber(result, "column_name"); + const std::string columnName{PQgetvalue(result, 0, columnNameFn)}; + PQclear(result); + if (columnName.empty()) { + throw std::runtime_error("No column name into result for column type"); + } else { + return columnName; + } + } + PQclear(result); + } + throw std::runtime_error("There is no a key name candidate"); +} + +static inline bool IsThereNext(PGconn * connection, const std::string & query) { + std::ostringstream oss; + oss << "select count(*) from (" << query << ") as t"; + const std::string count = oss.str(); + PGresult * result = PQexec(connection, count.c_str()); + if (PQresultStatus(result) != PGRES_TUPLES_OK) { + PQclear(result); + PQfinish(connection); + throw std::runtime_error("Count query batch"); + } + + const char * data = PQgetvalue(result, 0, 0); + char * end; + const std::size_t value = + static_cast(std::strtoll(data, &end, 10)); + PQclear(result); + + return value == 0; +} + +postgresql_data_provider::postgresql_data_provider( + const sql_info & sql, + std::size_t total_number_of_nodes, + std::size_t self_node_idx) + : abstractsql_data_provider(sql, total_number_of_nodes, self_node_idx), + table_fetch_completed{false}, batch_position{0}, + estimated_table_row_count{0} { connection = PQconnectdb(MakePostgreSQLConnectionString(sql).c_str()); if (PQstatus(connection) != CONNECTION_OK) { - std::cerr << "Connection to database failed: " << PQerrorMessage(connection) - << std::endl; throw std::runtime_error("Connection to database failed: " + std::string{PQerrorMessage(connection)}); } - std::cout << "PostgreSQL version: " << PQserverVersion(connection) - << std::endl; - TableInfo tableInfo = ExecuteTableInfo(connection, sql); - column_names = tableInfo.column_names; column_types = 
tableInfo.column_types; - column_bytes = tableInfo.column_bytes; + estimated_table_row_count = tableInfo.row_count; + keyname = FindKeyName(connection, sql); } postgresql_data_provider::~postgresql_data_provider() { PQfinish(connection); } std::shared_ptr postgresql_data_provider::clone() { return std::static_pointer_cast( - std::make_shared(sql, this->total_number_of_nodes, this->self_node_idx)); + std::make_shared(sql, + this->total_number_of_nodes, + this->self_node_idx)); } bool postgresql_data_provider::has_next() { - return table_fetch_completed == false; + return this->table_fetch_completed == false; } void postgresql_data_provider::reset() { - table_fetch_completed = true; - batch_position = 0; + this->table_fetch_completed = false; + this->batch_position = 0; } data_handle postgresql_data_provider::get_next(bool open_file) { @@ -131,28 +227,34 @@ data_handle postgresql_data_provider::get_next(bool open_file) { handle.sql_handle.column_names = column_names; handle.sql_handle.column_types = column_types; - if (!open_file) { return handle; } - - const std::string select_from = build_select_from(); - const std::string where = sql.table_filter.empty() ? "" : " where "; - const std::string query = select_from + where + sql.table_filter + - build_limit_offset(batch_position); + if (open_file == false) { return handle; } - batch_position += sql.table_batch_size; - PGresult *result = PQexecParams( - connection, query.c_str(), 0, nullptr, nullptr, nullptr, nullptr, 1); + std::ostringstream oss; + oss << build_select_query(batch_position, keyname); + const std::string query = oss.str(); + batch_position++; + PGresult * result = PQexec(connection, query.c_str()); if (PQresultStatus(result) != PGRES_TUPLES_OK) { PQclear(result); PQfinish(connection); throw std::runtime_error("Error getting next batch from postgresql"); } + PQflush(connection); int resultNtuples = PQntuples(result); + { + std::ostringstream oss; + oss << "QUERY: " << query << std::endl + << "COUNT: " << resultNtuples << std::endl; + std::cout << oss.str(); + } - if (!resultNtuples) { table_fetch_completed = true; } + if (!resultNtuples || + IsThereNext(connection, build_select_query(batch_position, keyname))) { + table_fetch_completed = true; + } - handle.sql_handle.column_bytes = column_bytes; handle.sql_handle.postgresql_result.reset(result, PQclear); handle.sql_handle.row_count = PQntuples(result); handle.uri = Uri("postgresql", "", sql.schema + "/" + sql.table, "", ""); diff --git a/engine/src/io/data_provider/sql/PostgreSQLDataProvider.h b/engine/src/io/data_provider/sql/PostgreSQLDataProvider.h index e3bc863c9..277e2936c 100644 --- a/engine/src/io/data_provider/sql/PostgreSQLDataProvider.h +++ b/engine/src/io/data_provider/sql/PostgreSQLDataProvider.h @@ -1,6 +1,7 @@ /* * Copyright 2021 BlazingDB, Inc. 
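// ---------------------------------------------------------------------------
// Illustrative sketch, not part of this patch: why FindKeyName() matters for
// the PostgreSQL provider. LIMIT/OFFSET without ORDER BY returns rows in an
// unspecified order, so batches could overlap or miss rows; ordering by the
// primary key (or, failing that, the best fixed-width column) keeps batches
// disjoint. The helper name below is hypothetical; only the query shapes
// mirror the code above.
#include <string>

// Batch N of the scan, as produced by build_select_query(batch, keyname):
//   SELECT <cols> FROM <table> [where <pushed-down filter>]
//   order by <keyname> LIMIT <batch_size> OFFSET <batch_size * batch>
//
// End-of-table probe, as in IsThereNext(): wrap the *next* batch in a count.
std::string sketch_count_probe(const std::string & next_batch_query) {
  return "select count(*) from (" + next_batch_query + ") as t";
}
// ---------------------------------------------------------------------------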
- * Copyright 2021 Cristhian Alberto Gonzales Castillo + * Copyright 2021 Cristhian Alberto Gonzales Castillo + * */ #ifndef POSTGRESQLDATAPROVIDER_H_ @@ -15,9 +16,9 @@ namespace io { class postgresql_data_provider : public abstractsql_data_provider { public: - postgresql_data_provider(const sql_info &sql, - size_t total_number_of_nodes, - size_t self_node_idx); + postgresql_data_provider(const sql_info & sql, + std::size_t total_number_of_nodes, + std::size_t self_node_idx); virtual ~postgresql_data_provider(); @@ -31,11 +32,16 @@ class postgresql_data_provider : public abstractsql_data_provider { std::size_t get_num_handles() override; +protected: + // TODO percy c.gonzales + std::unique_ptr get_predicate_transformer() const override { return nullptr; } + private: - PGconn *connection; + PGconn * connection; bool table_fetch_completed; std::size_t batch_position; std::size_t estimated_table_row_count; + std::string keyname; }; } /* namespace io */ diff --git a/engine/src/io/data_provider/sql/SQLiteDataProvider.cpp b/engine/src/io/data_provider/sql/SQLiteDataProvider.cpp index 28a3c573f..339bb07fa 100644 --- a/engine/src/io/data_provider/sql/SQLiteDataProvider.cpp +++ b/engine/src/io/data_provider/sql/SQLiteDataProvider.cpp @@ -1,6 +1,5 @@ /* - * Copyright 2021 BlazingDB, Inc. - * Copyright 2021 Percy Camilo Triveño Aucahuasi + * Copyright 2021 Percy Camilo Triveño Aucahuasi */ // NOTES @@ -12,6 +11,8 @@ - sql::SQLException (derived from std::runtime_error) */ +#include + #include "SQLiteDataProvider.h" #include "blazingdb/io/Util/StringUtil.h" @@ -28,11 +29,10 @@ struct sqlite_table_info { struct sqlite_columns_info { std::vector columns; std::vector types; - std::vector bytes; }; struct callb { - int sqlite_callback(void *NotUsed, int argc, char **argv, char **azColName) { + int sqlite_callback(void *, int argc, char ** argv, char ** azColName) { int i; for (i = 0; i < argc; i++) { printf("%s = %s\n", azColName[i], argv[i] ? 
argv[i] : "NULL"); @@ -42,32 +42,32 @@ struct callb { } }; -std::shared_ptr execute_sqlite_query(sqlite3 *conn, - const std::string &query) { - sqlite3_stmt *stmt; - const char *sql = query.c_str(); - int rc = sqlite3_prepare_v2(conn, sql, -1, &stmt, NULL); - if (rc != SQLITE_OK) { - printf("error: %s", sqlite3_errmsg(conn)); - // TODO percy error +static inline std::shared_ptr +execute_sqlite_query(sqlite3 * db, const std::string & query) { + sqlite3_stmt * stmt; + + int errorCode = sqlite3_prepare_v2(db, query.c_str(), -1, &stmt, nullptr); + if (errorCode != SQLITE_OK) { + std::ostringstream oss; + oss << "Executing SQLite query provider: " << std::endl + << "query: " << query << std::endl + << "error message: " << sqlite3_errmsg(db); + throw std::runtime_error{oss.str()}; } - auto sqlite_deleter = [](sqlite3_stmt *pointer) { - std::cout << "sqlite smt deleted!!!!\n"; - sqlite3_finalize(pointer); - }; - std::shared_ptr ret(stmt, sqlite_deleter); - return ret; + + return std::shared_ptr{stmt, sqlite3_finalize}; } -sqlite_table_info get_sqlite_table_info(sqlite3 *db, const std::string &table) { +static inline sqlite_table_info +get_sqlite_table_info(sqlite3 * db, const std::string & table) { sqlite_table_info ret; const std::string sql{"select count(*) from " + table}; int err = sqlite3_exec( db, sql.c_str(), - [](void *data, int count, char **rows, char **) -> int { + [](void * data, int count, char ** rows, char **) -> int { if (count == 1 && rows) { - sqlite_table_info &ret = *static_cast(data); + sqlite_table_info & ret = *static_cast(data); ret.partitions.push_back("default"); // check for partitions api ret.rows = static_cast(atoi(rows[0])); return 0; @@ -76,135 +76,158 @@ sqlite_table_info get_sqlite_table_info(sqlite3 *db, const std::string &table) { }, &ret, nullptr); - if (err != SQLITE_OK) { throw std::runtime_error("getting number of rows"); } - return ret; -} - -// TODO percy avoid code duplication -bool sqlite_is_string_col_type(const std::string &t) { - std::vector mysql_string_types_hints = { - "CHARACTER", - "VARCHAR", - "VARYING CHARACTER", - "NCHAR", - "NATIVE CHARACTER", - "NVARCHAR", - "TEXT", - "CLOB", - "STRING" // TODO percy ??? 
- }; - - for (auto hint : mysql_string_types_hints) { - if (StringUtil::beginsWith(t, hint)) return true; + if (err != SQLITE_OK) { + throw std::runtime_error{std::string{"getting number of rows"} + + sqlite3_errmsg(db)}; } - - return false; + return ret; } -sqlite_columns_info get_sqlite_columns_info(sqlite3 *conn, - const std::string &table) { - // TODO percy error handling - +static inline sqlite_columns_info +get_sqlite_columns_info(sqlite3 * db, const std::string & table) { sqlite_columns_info ret; std::string query = "PRAGMA table_info(" + table + ")"; - auto A = execute_sqlite_query(conn, query); - sqlite3_stmt *stmt = A.get(); + auto stmt_ptr = execute_sqlite_query(db, query); + sqlite3_stmt * stmt = stmt_ptr.get(); - int rc = 0; + int rc = SQLITE_ERROR; while ((rc = sqlite3_step(stmt)) == SQLITE_ROW) { - const unsigned char *name = sqlite3_column_text(stmt, 1); - std::string col_name((char *) name); + const unsigned char * name = sqlite3_column_text(stmt, 1); + std::string col_name(reinterpret_cast(name)); ret.columns.push_back(col_name); - const unsigned char *type = sqlite3_column_text(stmt, 2); - std::string col_type((char *) type); + const unsigned char * type = sqlite3_column_text(stmt, 2); + std::string col_type(reinterpret_cast(type)); + + std::transform( + col_type.cbegin(), + col_type.cend(), + col_type.begin(), + [](const std::string::value_type c) { return std::tolower(c); }); - size_t max_bytes = 8; // TODO percy check max scalar bytes from sqlite - if (sqlite_is_string_col_type(col_type)) { - //max_bytes = res->getUInt64("CHARACTER_MAXIMUM_LENGTH"); - // TODO percy see how to get the max size for string/txt cols ... see docs - max_bytes = 256; - } ret.types.push_back(col_type); } + if (rc != SQLITE_DONE) { - printf("error: %s", sqlite3_errmsg(conn)); - // TODO percy error + std::ostringstream oss; + oss << "Getting SQLite columns info: " << std::endl + << "query: " << query << std::endl + << "error message: " << sqlite3_errmsg(db); + throw std::runtime_error{oss.str()}; } return ret; } -sqlite_data_provider::sqlite_data_provider(const sql_info &sql, size_t total_number_of_nodes, +sqlite_data_provider::sqlite_data_provider(const sql_info & sql, + size_t total_number_of_nodes, size_t self_node_idx) - : abstractsql_data_provider(sql, total_number_of_nodes, self_node_idx), sqlite_connection(nullptr), - batch_position(0), current_row_count(0) { - sqlite3 *conn = nullptr; - int rc = sqlite3_open(sql.schema.c_str(), &conn); - - if (rc) { - fprintf(stderr, "Can't open database: %s\n", sqlite3_errmsg(conn)); - // TODO percy error - } else { - fprintf(stdout, "Opened sqlite database successfully\n"); + : abstractsql_data_provider{sql, total_number_of_nodes, self_node_idx}, + db{nullptr}, batch_position{0} { + int errorCode = sqlite3_open(sql.schema.c_str(), &db); + + if (errorCode != SQLITE_OK) { + throw std::runtime_error(std::string{"Can't open database: "} + + sqlite3_errmsg(db)); } - this->sqlite_connection = conn; - sqlite_table_info tbl_info = get_sqlite_table_info(conn, this->sql.table); - this->partitions = std::move(tbl_info.partitions); - this->row_count = tbl_info.rows; - sqlite_columns_info cols_info = - get_sqlite_columns_info(conn, this->sql.table); - this->column_names = cols_info.columns; - this->column_types = cols_info.types; - this->column_bytes = cols_info.bytes; -} + sqlite_table_info tbl_info = get_sqlite_table_info(db, sql.table); + partitions = std::move(tbl_info.partitions); + row_count = tbl_info.rows; -sqlite_data_provider::~sqlite_data_provider() { - 
sqlite3_close(this->sqlite_connection); + sqlite_columns_info cols_info = get_sqlite_columns_info(db, sql.table); + column_names = cols_info.columns; + column_types = cols_info.types; } +sqlite_data_provider::~sqlite_data_provider() { sqlite3_close(db); } + std::shared_ptr sqlite_data_provider::clone() { - return std::make_shared(this->sql, this->total_number_of_nodes, this->self_node_idx); + return std::make_shared(sql, + total_number_of_nodes, + self_node_idx); } bool sqlite_data_provider::has_next() { - return this->current_row_count < row_count; + // We need this implementation here becuase SQLite doesn't have a method to + // get the length of rows into a sqlite3_statement + const std::size_t offset = + sql.table_batch_size * + (batch_position * total_number_of_nodes + self_node_idx); + std::ostringstream oss; + oss << "SELECT * FROM " << sql.table << " LIMIT 1 OFFSET " << offset; + const std::string query = oss.str(); + bool it_has = false; + int errorCode = sqlite3_exec( + db, + query.c_str(), + [](void * data, int count, char ** rows, char **) -> int { + *static_cast(data) = count > 0 && rows; + return 0; + }, + &it_has, + nullptr); + if (errorCode != SQLITE_OK) { + throw std::runtime_error{std::string{"Has next SQLite batch: "} + + sqlite3_errmsg(db)}; + } + return it_has; } -void sqlite_data_provider::reset() { this->batch_position = 0; } +void sqlite_data_provider::reset() { batch_position = 0; } -data_handle sqlite_data_provider::get_next(bool) { - std::string query; +static inline std::size_t get_size_for_statement(sqlite3_stmt * stmt) { + std::ostringstream oss; + oss << "select count(*) from (" << sqlite3_expanded_sql(stmt) << ')' + << std::endl; + std::string query = oss.str(); - query = "SELECT * FROM " + this->sql.table + " LIMIT " + - std::to_string(this->batch_position + this->sql.table_batch_size) + - " OFFSET " + std::to_string(this->batch_position); - this->batch_position += this->sql.table_batch_size; + std::size_t nRows = 0; + const std::int32_t errorCode = sqlite3_exec( + sqlite3_db_handle(stmt), + query.c_str(), + [](void * data, int count, char ** rows, char **) -> int { + if (count == 1 && rows) { + *static_cast(data) = + static_cast(std::atoi(rows[0])); + return 0; + } + return 1; + }, + &nRows, + nullptr); + if (errorCode != SQLITE_OK) { + throw std::runtime_error{std::string{"Has next SQLite batch: "} + + sqlite3_errstr(errorCode)}; + } + return nRows; +} - std::cout << "query: " << query << "\n"; - auto stmt = execute_sqlite_query(this->sqlite_connection, query); - current_row_count += batch_position; +data_handle sqlite_data_provider::get_next(bool open_file) { data_handle ret; - ret.sql_handle.table = this->sql.table; - ret.sql_handle.column_names = this->column_names; - ret.sql_handle.column_types = this->column_types; - ret.sql_handle.column_bytes = this->column_bytes; + + ret.sql_handle.table = sql.table; + ret.sql_handle.column_names = column_names; + ret.sql_handle.column_types = column_types; + + if (open_file == false) { return ret; } + + const std::string query = build_select_query(batch_position); + batch_position++; + + std::shared_ptr stmt = execute_sqlite_query(db, query); + + ret.sql_handle.row_count = get_size_for_statement(stmt.get()); ret.sql_handle.sqlite_statement = stmt; + // TODO percy add columns to uri.query - ret.uri = Uri("mysql", "", this->sql.schema + "/" + this->sql.table, "", ""); - // std::cout << "get_next TOTAL rows: " << this->row_count << "\n"; - // std::cout << "get_next current_row_count: " << this->current_row_count << 
"\n"; + ret.uri = Uri("sqlite", "", sql.schema + "/" + sql.table, "", ""); return ret; } size_t sqlite_data_provider::get_num_handles() { - if (this->partitions.empty()) { - size_t ret = this->row_count / this->sql.table_batch_size; - return ret == 0 ? 1 : ret; - } - - return this->partitions.size(); + std::size_t ret = row_count / sql.table_batch_size; + return ret == 0 ? 1 : ret; } } /* namespace io */ diff --git a/engine/src/io/data_provider/sql/SQLiteDataProvider.h b/engine/src/io/data_provider/sql/SQLiteDataProvider.h index b7a62ab73..b2568b9c6 100644 --- a/engine/src/io/data_provider/sql/SQLiteDataProvider.h +++ b/engine/src/io/data_provider/sql/SQLiteDataProvider.h @@ -1,6 +1,6 @@ /* - * Copyright 2021 BlazingDB, Inc. - * Copyright 2021 Percy Camilo Triveño Aucahuasi + * Copyright 2021 Percy Camilo Triveño Aucahuasi + * Copyright 2021 Cristhian Alberto Gonzales Castillo */ #ifndef SQLITEDATAPROVIDER_H_ @@ -20,41 +20,45 @@ namespace io { */ class sqlite_data_provider : public abstractsql_data_provider { public: - sqlite_data_provider(const sql_info &sql, size_t total_number_of_nodes, - size_t self_node_idx); + sqlite_data_provider(const sql_info & sql, + std::size_t total_number_of_nodes, + std::size_t self_node_idx); virtual ~sqlite_data_provider(); - std::shared_ptr clone() override; + std::shared_ptr clone() override; /** - * tells us if this provider can generate more sql resultsets - */ - bool has_next() override; + * tells us if this provider can generate more sql resultsets + */ + bool has_next() override; /** - * Resets file read count to 0 for file based DataProvider - */ - void reset() override; + * Resets file read count to 0 for file based DataProvider + */ + void reset() override; /** - * gets us the next arrow::io::RandomAccessFile - * if open_file is false will not run te query and just returns a data_handle - * with columns info - */ - data_handle get_next(bool = true) override; + * gets us the next arrow::io::RandomAccessFile + * if open_file is false will not run te query and just returns a data_handle + * with columns info + */ + data_handle get_next(bool = true) override; /** - * Get the number of data_handles that will be provided. - */ - size_t get_num_handles() override; + * Get the number of data_handles that will be provided. 
+ */ + size_t get_num_handles() override; + +protected: + // TODO percy c.gonzales + std::unique_ptr get_predicate_transformer() const override { return nullptr; } private: - sqlite3* sqlite_connection; + sqlite3 * db; std::vector partitions; - size_t row_count; - size_t batch_position; - size_t current_row_count; + std::size_t row_count; + std::size_t batch_position; }; } /* namespace io */ diff --git a/engine/src/parser/expression_tree.hpp b/engine/src/parser/expression_tree.hpp index 2c3584927..308138038 100644 --- a/engine/src/parser/expression_tree.hpp +++ b/engine/src/parser/expression_tree.hpp @@ -79,7 +79,14 @@ struct variable_node : public operad_node { }; struct operator_node : public node { - operator_node(const std::string& value) : node{ node_type::OPERATOR, value } {}; + enum placement_type { + AUTO, + BEGIN, + MIDDLE, + END + }; + + operator_node(const std::string& value) : node{ node_type::OPERATOR, value }, label(value) {}; node * clone() const override { node * ret = new operator_node(this->value); @@ -108,6 +115,10 @@ struct operator_node : public node { } return transformer.transform(*this); } + + placement_type placement = placement_type::AUTO; + std::string label; + bool parentheses_wrap = true; }; namespace detail { diff --git a/engine/tests/provider/sql_provider_test.cpp b/engine/tests/provider/sql_provider_test.cpp index 70a2fc0ef..12ebef981 100644 --- a/engine/tests/provider/sql_provider_test.cpp +++ b/engine/tests/provider/sql_provider_test.cpp @@ -20,7 +20,7 @@ TEST_F(SQLProviderTest, DISABLED_postgresql_select_all) { sql.user = "myadmin"; sql.password = ""; sql.schema = "pagila"; - sql.table = "prueba4"; + sql.table = "prueba5"; sql.table_filter = ""; sql.table_batch_size = 2000; @@ -53,13 +53,13 @@ TEST_F(SQLProviderTest, DISABLED_postgresql_select_all) { std::cout << "SCHEMA" << std::endl << " length = " << schema.get_num_columns() << std::endl << " columns" << std::endl; - for (std::size_t i = 0; i < schema.get_num_columns(); i++) { - const std::string &name = schema.get_name(i); + for(std::size_t i = 0; i < schema.get_num_columns(); i++) { + const std::string & name = schema.get_name(i); std::cout << " " << name << ": "; try { const std::string dtypename = dt2name[schema.get_dtype(i)]; std::cout << dtypename << std::endl; - } catch (std::exception &) { + } catch(std::exception &) { std::cout << static_cast(schema.get_dtype(i)) << std::endl; } } @@ -67,7 +67,7 @@ TEST_F(SQLProviderTest, DISABLED_postgresql_select_all) { auto num_cols = schema.get_num_columns(); std::vector column_indices(num_cols); - std::iota(column_indices.begin(), column_indices.end(), 0); + std::iota(column_indices.begin(), column_indices.end(), 0); std::vector row_groups; auto table = parser.parse_batch(handle, schema, column_indices, row_groups); @@ -78,19 +78,23 @@ TEST_F(SQLProviderTest, DISABLED_postgresql_select_all) { } -void print_batch(const ral::io::data_handle &handle, - const ral::io::Schema &schema, - ral::io::mysql_parser &parser, - const std::vector &column_indices) { +void print_batch(const ral::io::data_handle & handle, + const ral::io::Schema & schema, + ral::io::mysql_parser & parser, + const std::vector & column_indices) { std::vector row_groups; - std::unique_ptr bztbl = parser.parse_batch(handle, schema, column_indices, row_groups); - static int i = 0; - ral::utilities::print_blazing_table_view(bztbl->toBlazingTableView(), "holis"+std::to_string(++i)); + std::unique_ptr bztbl = + parser.parse_batch(handle, schema, column_indices, row_groups); + static int i = 0; + 
ral::utilities::print_blazing_table_view( + bztbl->toBlazingTableView(), "holis" + std::to_string(++i)); + std::cout << "FINISHED PRINTING CUDF TABLE!!! \n"; } TEST_F(SQLProviderTest, DISABLED_mysql_select_all) { ral::io::sql_info sql; sql.host = "localhost"; + //sql.port = 5432; // pg sql.port = 3306; // sql.user = "blazing"; // sql.password = "admin"; @@ -111,41 +115,51 @@ TEST_F(SQLProviderTest, DISABLED_mysql_select_all) { sql.schema = "tpch"; - //sql.table = "lineitem"; - sql.table = "nation"; + sql.table = "lineitem"; + //sql.table = "nation"; + //sql.table = "orders"; sql.table_filter = ""; - sql.table_batch_size = 2000; + sql.table_batch_size = 200000; sql.table_batch_size = 2; - auto mysql_provider = std::make_shared(sql, 1, 0); + auto mysql_provider = + std::make_shared(sql, 1, 0); int rows = mysql_provider->get_num_handles(); ral::io::mysql_parser parser; ral::io::Schema schema; auto handle = - mysql_provider->get_next(false); // false so we make sure dont go to the - // db and get the schema info only + mysql_provider->get_next(false); // false so we make sure we don't go to + // the db and get the schema info only parser.parse_schema(handle, schema); std::vector column_indices; - //std::vector column_indices = {0, 6}; + // std::vector column_indices = {0, 6}; // std::vector column_indices = {0, 4}; // line item id fgloat // std::vector column_indices = {4}; // line item fgloat // std::vector column_indices = {8}; // line item ret_flag // std::vector column_indices = {1}; // nation 1 name - if (column_indices.empty()) { + if(column_indices.empty()) { size_t num_cols = schema.get_num_columns(); column_indices.resize(num_cols); std::iota(column_indices.begin(), column_indices.end(), 0); } mysql_provider->set_column_indices(column_indices); + //std::string exp = "BindableTableScan(table=[[main, lineitem]], filters=[[OR(AND(>($0, 599990), <=($3, 1998-09-02)), AND(<>(-($0, 1), +(65, /(*(*(98, $0), 2), 3))), IS NOT NULL($1)))]], projects=[[0, 1, 9, 10]], aliases=[[l_orderkey, l_partkey, l_linestatus, l_shipdate]])"; + //std::string exp = "BindableTableScan(table=[[main, orders]], filters=[[NOT(LIKE($2, '%special%requests%'))]], projects=[[0, 1, 8]], aliases=[[o_orderkey, o_custkey, o_comment]])"; + //std::string exp = "BindableTableScan(table=[[main, lineitem]], filters=[[AND(OR(=($4, 'MAIL'), =($4, 'SHIP')), <($2, $3), <($1, $2), >=($3, 1994-01-01), <($3, 1995-01-01))]], projects=[[0, 10, 11, 12, 14]], aliases=[[l_orderkey, l_shipdate, l_commitdate, l_receiptdate, l_shipmode]])"; + std::string exp = "BindableTableScan(table=[[main, lineitem]], filters=[[AND(>=($3, 1995-09-01), <($3, 1995-10-01))]], projects=[[1, 5, 6, 10]], aliases=[[l_partkey, l_extendedprice, l_discount, l_shipdate]])"; + + + mysql_provider->set_predicate_pushdown(exp); + std::cout << "\tTABLE\n"; auto cols = schema.get_names(); std::cout << "total cols: " << cols.size() << "\n"; - for (int i = 0; i < cols.size(); ++i) { + for(int i = 0; i < cols.size(); ++i) { std::cout << "\ncol: " << schema.get_name(i) << "\n"; std::cout << "\ntyp: " << (int32_t) schema.get_dtype(i) << "\n"; } @@ -153,17 +167,17 @@ TEST_F(SQLProviderTest, DISABLED_mysql_select_all) { std::cout << "\n\nCUDFFFFFFFFFFFFFFFFFFFFFF\n"; bool only_once = false; - if (only_once) { + if(only_once) { std::cout << "\trows: " << rows << "\n"; handle = mysql_provider->get_next(); auto res = handle.sql_handle.mysql_resultset; - + bool has_next = mysql_provider->has_next(); std::cout << "\tNEXT?: " << (has_next ?
"TRUE" : "FALSE") << "\n"; print_batch(handle, schema, parser, column_indices); } else { mysql_provider->reset(); - while (mysql_provider->has_next()) { + while(mysql_provider->has_next()) { handle = mysql_provider->get_next(); print_batch(handle, schema, parser, column_indices); } @@ -173,7 +187,7 @@ TEST_F(SQLProviderTest, DISABLED_mysql_select_all) { TEST_F(SQLProviderTest, DISABLED_sqlite_select_all) { ral::io::sql_info sql; sql.schema = "/blazingsql/db.sqlite3"; - sql.table = "prueba"; + sql.table = "prueba2"; sql.table_filter = ""; sql.table_batch_size = 2000; @@ -205,13 +219,13 @@ TEST_F(SQLProviderTest, DISABLED_sqlite_select_all) { std::cout << "SCHEMA" << std::endl << " length = " << schema.get_num_columns() << std::endl << " columns" << std::endl; - for (std::size_t i = 0; i < schema.get_num_columns(); i++) { - const std::string &name = schema.get_name(i); + for(std::size_t i = 0; i < schema.get_num_columns(); i++) { + const std::string & name = schema.get_name(i); std::cout << " " << name << ": "; try { const std::string dtypename = dt2name[schema.get_dtype(i)]; std::cout << dtypename << std::endl; - } catch (std::exception &) { + } catch(std::exception &) { std::cout << static_cast(schema.get_dtype(i)) << std::endl; } } @@ -227,4 +241,10 @@ TEST_F(SQLProviderTest, DISABLED_sqlite_select_all) { std::cout << "TABLE" << std::endl << " ncolumns = " << table->num_columns() << std::endl << " nrows = " << table->num_rows() << std::endl; + + auto tv = table->toBlazingTableView(); + + for(cudf::size_type i = 0; i < static_cast(num_cols); i++) { + cudf::test::print(tv.column(i)); + } } diff --git a/pyblazing/pyblazing/apiv2/context.py b/pyblazing/pyblazing/apiv2/context.py index 270deba26..39fe84bdb 100755 --- a/pyblazing/pyblazing/apiv2/context.py +++ b/pyblazing/pyblazing/apiv2/context.py @@ -922,7 +922,7 @@ def kwargs_validation(kwargs, bc_api_str): "port", "username", "password", - "schema", + "database", "table_filter", "table_batch_size", ] @@ -3142,7 +3142,7 @@ def sql( elif ( query_table.fileType == DataType.MYSQL or query_table.fileType == DataType.SQLITE - # or query_table.fileType == DataType. 
+ or query_table.fileType == DataType.POSTGRESQL ): if query_table.has_metadata(): currentTableNodes = self._optimize_skip_data_getSlices( diff --git a/tests/BlazingSQLTest/DataBase/createSchema.py b/tests/BlazingSQLTest/DataBase/createSchema.py index 4f6e4fdb6..5695345c2 100644 --- a/tests/BlazingSQLTest/DataBase/createSchema.py +++ b/tests/BlazingSQLTest/DataBase/createSchema.py @@ -61,23 +61,56 @@ def __init__(self, **kwargs): def get_sql_connection(fileSchemaType: DataType): - sql_hostname = os.getenv("BLAZINGSQL_E2E_SQL_HOSTNAME", "") - if fileSchemaType in [DataType.MYSQL, DataType.POSTGRESQL]: - if not sql_hostname: return None + if fileSchemaType is DataType.SQLITE: + return get_sqlite_connection() - sql_port = int(os.getenv("BLAZINGSQL_E2E_SQL_PORT", 0)) - if fileSchemaType in [DataType.MYSQL, DataType.POSTGRESQL]: - if sql_port == 0: return None + if fileSchemaType is DataType.MYSQL: + return get_mysql_connection() - sql_username = os.getenv("BLAZINGSQL_E2E_SQL_USERNAME", "") - if fileSchemaType in [DataType.MYSQL, DataType.POSTGRESQL]: - if not sql_username: return None + if fileSchemaType is DataType.POSTGRESQL: + return get_postgresql_connection() - sql_password = os.getenv("BLAZINGSQL_E2E_SQL_PASSWORD", "") - if fileSchemaType in [DataType.MYSQL, DataType.POSTGRESQL]: - if not sql_password: return None + raise ValueError(f'Unsupported data type {fileSchemaType}') - sql_schema = os.getenv("BLAZINGSQL_E2E_SQL_SCHEMA", "") + +def get_mysql_connection() -> sql_connection: + sql_hostname = os.getenv("BLAZINGSQL_E2E_MYSQL_HOSTNAME", "") + if not sql_hostname: return None + + sql_port = int(os.getenv("BLAZINGSQL_E2E_MYSQL_PORT", 0)) + if sql_port == 0: return None + + sql_username = os.getenv("BLAZINGSQL_E2E_MYSQL_USERNAME", "") + if not sql_username: return None + + sql_password = os.getenv("BLAZINGSQL_E2E_MYSQL_PASSWORD", "") + if not sql_password: return None + + sql_schema = os.getenv("BLAZINGSQL_E2E_MYSQL_DATABASE", "") + if not sql_schema: return None + + ret = sql_connection(hostname = sql_hostname, + port = sql_port, + username = sql_username, + password = sql_password, + schema = sql_schema) + return ret + + +def get_postgresql_connection() -> sql_connection: + sql_hostname = os.getenv("BLAZINGSQL_E2E_POSTGRESQL_HOSTNAME", "") + if not sql_hostname: return None + + sql_port = int(os.getenv("BLAZINGSQL_E2E_POSTGRESQL_PORT", 0)) + if sql_port == 0: return None + + sql_username = os.getenv("BLAZINGSQL_E2E_POSTGRESQL_USERNAME", "") + if not sql_username: return None + + sql_password = os.getenv("BLAZINGSQL_E2E_POSTGRESQL_PASSWORD", "") + if not sql_password: return None + + sql_schema = os.getenv("BLAZINGSQL_E2E_POSTGRESQL_DATABASE", "") if not sql_schema: return None ret = sql_connection(hostname = sql_hostname, @@ -88,6 +121,16 @@ def get_sql_connection(fileSchemaType: DataType): return ret +def get_sqlite_connection() -> sql_connection: + sql_schema = os.getenv("BLAZINGSQL_E2E_SQLITE_DATABASE", "") + if not sql_schema: return None + return sql_connection(hostname='', + port=0, + username='', + password='', + schema=sql_schema) + + def getFiles_to_tmp(tpch_dir, n_files, ext): list_files = [] for name in tableNames: @@ -1319,7 +1362,7 @@ def create_tables(bc, dir_data_lc, fileSchemaType, **kwargs): port = sql_port, username = sql_username, password = sql_password, - schema = sql_schema, + database = sql_schema, table_filter = sql_table_filter, table_batch_size = sql_table_batch_size) else: diff --git a/tests/BlazingSQLTest/DataBase/mysqlSchema.py
b/tests/BlazingSQLTest/DataBase/mysqlSchema.py index 75ab92d68..dd1724fa0 100644 --- a/tests/BlazingSQLTest/DataBase/mysqlSchema.py +++ b/tests/BlazingSQLTest/DataBase/mysqlSchema.py @@ -143,7 +143,7 @@ def create_and_load_tpch_schema(sql: sql_connection, only_create_tables : bool = for table, table_description in mysql_tpch_table_descriptions.items(): ok = create_mysql_table(table_description, cursor) if ok and not only_create_tables: - table_files = "%s/%s*.psv" % (tabs_dir, table) + table_files = "%s/%s_*.psv" % (tabs_dir, table) mysql_load_data_in_file(table, table_files, cursor, cnx) else: print("MySQL table %s already exists, will not load any data!" % table) diff --git a/tests/BlazingSQLTest/DataBase/postgreSQLSchema.py b/tests/BlazingSQLTest/DataBase/postgreSQLSchema.py new file mode 100644 index 000000000..0afbacc56 --- /dev/null +++ b/tests/BlazingSQLTest/DataBase/postgreSQLSchema.py @@ -0,0 +1,137 @@ +import os +import glob + +from .createSchema import get_sql_connection, get_column_names, sql_connection + +from blazingsql import DataType + +import psycopg2 + + +postgresql_tpch_table_descriptions = { + "nation": """create table nation ( n_nationkey integer, + n_name char(25), + n_regionkey integer, + n_comment varchar(152))""", + "region": """create table region ( r_regionkey integer, + r_name char(25), + r_comment varchar(152))""", + "part": """create table part ( p_partkey integer , + p_name varchar(55) , + p_mfgr char(25) , + p_brand char(10) , + p_type varchar(25) , + p_size integer , + p_container char(10) , + p_retailprice decimal(15,2) , + p_comment varchar(23) )""", + "supplier": """create table supplier ( s_suppkey integer , + s_name char(25) , + s_address varchar(40) , + s_nationkey integer , + s_phone char(15) , + s_acctbal decimal(15,2) , + s_comment varchar(101) )""", + "partsupp": """create table partsupp ( ps_partkey integer , + ps_suppkey integer , + ps_availqty integer , + ps_supplycost decimal(15,2) , + ps_comment varchar(199) );""", + "customer": """create table customer ( c_custkey integer , + c_name varchar(25) , + c_address varchar(40) , + c_nationkey integer , + c_phone char(15) , + c_acctbal decimal(15,2) , + c_mktsegment char(10) , + c_comment varchar(117) );""", + "orders": """create table orders ( o_orderkey integer , + o_custkey integer , + o_orderstatus char(1) , + o_totalprice decimal(15,2) , + o_orderdate date , + o_orderpriority char(15) , + o_clerk char(15) , + o_shippriority integer , + o_comment varchar(79) )""", + "lineitem": """create table lineitem ( l_orderkey integer , + l_partkey integer , + l_suppkey integer , + l_linenumber integer , + l_quantity decimal(15,2) , + l_extendedprice decimal(15,2) , + l_discount decimal(15,2) , + l_tax decimal(15,2) , + l_returnflag char(1) , + l_linestatus char(1) , + l_shipdate date , + l_commitdate date , + l_receiptdate date , + l_shipinstruct char(25) , + l_shipmode char(10) , + l_comment varchar(44) )""", +} + + +# if table already exists returns False +def create_postgresql_table(table_description: str, cursor) -> bool: + print("Creating table {}: ".format(table_description), end='') + cursor.execute(table_description) + return True + + +def copy(cursor, csvFile, tableName): + csvDelimiter = "'|'" + csvQuoteCharacter = "'\"'" + query = "COPY %s FROM STDIN WITH CSV QUOTE %s DELIMITER AS %s NULL as 'null'" % (tableName, csvQuoteCharacter, csvDelimiter) + cursor.copy_expert(sql = query, file = csvFile) + + +def postgresql_load_data_in_file(table: str, full_path_wildcard: str, cursor, cnx): + cols = 
get_column_names(table) + h = "" + b = "" + for i,c in enumerate(cols): + h = h + "@" + c + hj = "%s = NULLIF(@%s,'null')" % (c,c) + b = b + hj + if i + 1 != len(cols): + h = h + ",\n" + b = b + ",\n" + + a = glob.glob(full_path_wildcard) + for fi in a: + with open(fi, 'r') as csvFile: + copy(cursor, csvFile, table) + cnx.commit() + print("load data done!") + + +# using the nulls dataset +def create_and_load_tpch_schema(sql: sql_connection, only_create_tables : bool = False): + #allow_local_infile = True) + cnx = psycopg2.connect( + dbname=sql.schema, + user=sql.username, + host=sql.hostname, + port=int(sql.port), + password=sql.password + ) + cursor = cnx.cursor() + + conda_prefix = os.getenv("CONDA_PREFIX", "") + tabs_dir = conda_prefix + "/" + "blazingsql-testing-files/data/tpch-with-nulls/" + + for table, table_description in postgresql_tpch_table_descriptions.items(): + cursor.execute(f"DROP TABLE IF EXISTS {table}") + cnx.commit() + ok = create_postgresql_table(table_description, cursor) + cnx.commit() + if ok and not only_create_tables: + table_files = "%s/%s_*.psv" % (tabs_dir, table) + postgresql_load_data_in_file(table, table_files, cursor, cnx) + else: + print("MySQL table %s already exists, will not load any data!" % table) + + cursor.close() + cnx.close() diff --git a/tests/BlazingSQLTest/DataBase/sqliteSchema.py b/tests/BlazingSQLTest/DataBase/sqliteSchema.py new file mode 100644 index 000000000..ab59c1b7d --- /dev/null +++ b/tests/BlazingSQLTest/DataBase/sqliteSchema.py @@ -0,0 +1,138 @@ +import csv +import glob +import os +import sqlite3 + +from tempfile import NamedTemporaryFile + +from .createSchema import sql_connection + + +sqlite_tpch_table_descriptions = { + 'nation': '''create table nation ( n_nationkey integer, + n_name char(25), + n_regionkey integer, + n_comment varchar(152))''', + 'region': '''create table region ( r_regionkey integer, + r_name char(25), + r_comment varchar(152))''', + 'part': '''create table part ( p_partkey integer , + p_name varchar(55) , + p_mfgr char(25) , + p_brand char(10) , + p_type varchar(25) , + p_size integer , + p_container char(10) , + p_retailprice decimal(15,2) , + p_comment varchar(23) )''', + 'supplier': '''create table supplier ( s_suppkey integer , + s_name char(25) , + s_address varchar(40) , + s_nationkey integer , + s_phone char(15) , + s_acctbal decimal(15,2) , + s_comment varchar(101) )''', + 'partsupp': '''create table partsupp ( ps_partkey integer , + ps_suppkey integer , + ps_availqty integer , + ps_supplycost decimal(15,2) , + ps_comment varchar(199) );''', + 'customer': '''create table customer ( c_custkey integer , + c_name varchar(25) , + c_address varchar(40) , + c_nationkey integer , + c_phone char(15) , + c_acctbal decimal(15,2) , + c_mktsegment char(10) , + c_comment varchar(117) );''', + 'orders': '''create table orders ( o_orderkey integer , + o_custkey integer , + o_orderstatus char(1) , + o_totalprice decimal(15,2) , + o_orderdate date , + o_orderpriority char(15) , + o_clerk char(15) , + o_shippriority integer , + o_comment varchar(79) )''', + 'lineitem': '''create table lineitem ( l_orderkey integer , + l_partkey integer , + l_suppkey integer , + l_linenumber integer , + l_quantity decimal(15,2) , + l_extendedprice decimal(15,2) , + l_discount decimal(15,2) , + l_tax decimal(15,2) , + l_returnflag char(1) , + l_linestatus char(1) , + l_shipdate date , + l_commitdate date , + l_receiptdate date , + l_shipinstruct char(25) , + l_shipmode char(10) , + l_comment varchar(44) )''', +} + + +def 
create_sqlite_table(table_description: str, + cursor: sqlite3.Cursor) -> bool: + try: + print(f'Creating table {table_description}', end='') + cursor.execute(table_description) + except sqlite3.OperationalError as error: + if 'already exists' in str(error): + print('already exists.') + return False + else: + raise ValueError( + f'Error creating from\n{table_description}') from error + else: + print('OK') + return True + + +def sqlite_load_data_in_file(table: str, + full_path_wildcard: str, + cursor: sqlite3.Cursor, + connection: sqlite3.Connection): + psvpaths = glob.glob(full_path_wildcard) + for psvpath in psvpaths: + with open(psvpath) as psv: + reader = csv.reader(psv, delimiter='|') + row = next(reader) + row = [c if c.lower() != 'null' else None for c in row] + nfields = ','.join('?' * len(row)) + query = f'insert into {table} values ({nfields})' + cursor.execute(query, row) + for row in reader: + row = [c if c.lower() != 'null' else None for c in row] + cursor.execute(query, row) + connection.commit() + + +def create_and_load_tpch_schema(sql: sql_connection, + only_create_tables: bool = False): + schema = sql.schema + if not schema: + temporaryFile = NamedTemporaryFile(delete=False) + schema = temporaryFile.name + connection = sqlite3.connect(schema) + + cursor = connection.cursor() + + conda_prefix = os.environ.get('CONDA_PREFIX', '') + tabs_dir = os.path.join( + conda_prefix, + 'blazingsql-testing-files/data/tpch-with-nulls/') + + for table, table_description in sqlite_tpch_table_descriptions.items(): + ok = create_sqlite_table(table_description, cursor) + if ok and not only_create_tables: + table_files = '%s/%s_*.psv' % (tabs_dir, table) + sqlite_load_data_in_file(table, table_files, cursor, connection) + else: + print( + 'SQLite table %s already exists, will not load any data!' % + table) + + cursor.close() + connection.close() diff --git a/tests/BlazingSQLTest/EndToEndTests/allE2ETest.py b/tests/BlazingSQLTest/EndToEndTests/allE2ETest.py index 4151a396b..d71b3eefc 100644 --- a/tests/BlazingSQLTest/EndToEndTests/allE2ETest.py +++ b/tests/BlazingSQLTest/EndToEndTests/allE2ETest.py @@ -306,7 +306,7 @@ def main(): if testsWithNulls == "true": if Settings.execution_mode != ExecutionMode.GPUCI: if runAllTests or ("tablesFromSQL" in targetTestGroups): - tablesFromSQL.main(dask_client, drill, dir_data_file, bc, nRals) + tablesFromSQL.main(dask_client, drill, spark, dir_data_file, bc, nRals) # WARNING!!! 
This Test must be the last one to test ------------------------------------------------------------------------------------------------------------------------------------------- if runAllTests or ("configOptionsTest" in targetTestGroups): diff --git a/tests/BlazingSQLTest/EndToEndTests/tablesFromSQL.py b/tests/BlazingSQLTest/EndToEndTests/tablesFromSQL.py index c4e5ee748..bf5a18148 100644 --- a/tests/BlazingSQLTest/EndToEndTests/tablesFromSQL.py +++ b/tests/BlazingSQLTest/EndToEndTests/tablesFromSQL.py @@ -1,17 +1,64 @@ +from collections import OrderedDict + from blazingsql import DataType from DataBase import createSchema from Configuration import ExecutionMode from Configuration import Settings as Settings from Runner import runTest from Utils import gpuMemory, skip_test -from EndToEndTests.tpchQueries import get_tpch_query +from EndToEndTests import tpchQueries + + +class Sample: + def __init__(self, **kwargs): + self.sample_id = kwargs.get("id", "") + self.query = kwargs.get("query", "") + self.table_mapper = kwargs.get("table_mapper", Sample.default_table_mapper) + self.worder = kwargs.get("worder", 1) + self.use_percentage = kwargs.get("use_percentage", False) + self.acceptable_difference = kwargs.get("acceptable_difference", 0.01) + self.use_pyspark = kwargs.get("use_pyspark", False) # we use drill by default + + def default_table_mapper(query, tables = {}): + return query + + +# pass table_mapper to apply the same global table_mapper to all the samples +def define_samples(sample_list: [Sample], table_mapper = None): + ret = OrderedDict() + i = 1 + for sample in sample_list: + if table_mapper: + sample.table_mapper = table_mapper # override with global table_mapper + istr = str(i) if i > 9 else "0"+str(i) + sampleId = sample.sample_id + if not sampleId: + sampleId = "TEST_" + istr + i = i + 1 + ret[sampleId] = sample + return ret + queryType = "TablesFromSQL" + +samples = define_samples([ + Sample(query = tpchQueries.query_templates["TEST_13"]), + Sample(query = tpchQueries.query_templates["TEST_07"]), + Sample(query = tpchQueries.query_templates["TEST_12"]), + Sample(query = tpchQueries.query_templates["TEST_04"]), + Sample(query = tpchQueries.query_templates["TEST_01"]), + Sample(query = "select * from {nation}", use_pyspark = True), + Sample(query = tpchQueries.query_templates["TEST_08"], use_pyspark = True), + Sample(query = """select c_custkey, c_nationkey, c_acctbal + from {customer} where c_custkey < 150 and c_nationkey = 5 + or c_custkey = 200 or c_nationkey >= 10 + or c_acctbal <= 500""") +], tpchQueries.map_tables) + data_types = [ - DataType.MYSQL, + #DataType.MYSQL, #DataType.POSTGRESQL, - #DataType.SQLITE, - # TODO percy c.gonzales + DataType.SQLITE ] tables = [ @@ -29,6 +76,7 @@ # "lineitem": "l_quantity < 24", } + + +# approx.
taken from parquet parts (tpch with nulls 2 parts) sql_table_batch_sizes = { "nation": 30, @@ -41,23 +89,6 @@ "partsupp": 40000, } -tpch_queries = [ - "TEST_13", - "TEST_07", - "TEST_12", - "TEST_04", - "TEST_01", -] - -# Parameter to indicate if its necessary to order -# the resulsets before compare them -worder = 1 -use_percentage = False -acceptable_difference = 0.01 - -# example: {csv: {tb1: tb1_csv, ...}, parquet: {tb1: tb1_parquet, ...}} -datasource_tables = dict((ds, dict((t, t+"_"+str(ds).split(".")[1]) for t in tables)) for ds in data_types) - def datasources(dask_client, nRals): for fileSchemaType in data_types: @@ -66,37 +97,44 @@ def datasources(dask_client, nRals): yield fileSchemaType -def samples(bc, dask_client, nRals, **kwargs): +def sample_items(bc, dask_client, nRals, **kwargs): + # example: {csv: {tb1: tb1_csv, ...}, parquet: {tb1: tb1_parquet, ...}} + dstables = dict((ds, dict((t, t+"_"+str(ds).split(".")[1]) for t in tables)) for ds in data_types) + init_tables = kwargs.get("init_tables", False) sql_table_filter_map = kwargs.get("sql_table_filter_map", {}) sql_table_batch_size_map = kwargs.get("sql_table_batch_size_map", {}) sql = kwargs.get("sql_connection", None) + dir_data_lc = kwargs.get("dir_data_lc", "") + for fileSchemaType in datasources(dask_client, nRals): - dstables = datasource_tables[fileSchemaType] + datasource_tables = dstables[fileSchemaType] if init_tables: print("Creating tables for", str(fileSchemaType)) - table_names=list(dstables.values()) - createSchema.create_tables(bc, "", fileSchemaType, - tables = tables, - table_names=table_names, - sql_table_filter_map = sql_table_filter_map, - sql_table_batch_size_map = sql_table_batch_size_map, - sql_connection = sql, - ) + table_names = list(datasource_tables.values()) + if isinstance(sql, createSchema.sql_connection): # create sql tables + createSchema.create_tables(bc, "", fileSchemaType, + tables = tables, + table_names=table_names, + sql_table_filter_map = sql_table_filter_map, + sql_table_batch_size_map = sql_table_batch_size_map, + sql_connection = sql, + ) + else: # create in file tables (parquet, csv, etc) + createSchema.create_tables(bc, dir_data_lc, fileSchemaType, + tables = tables, + table_names=table_names + ) + print("All tables were created for", str(fileSchemaType)) - i = 0 - queries = [get_tpch_query(q, dstables) for q in tpch_queries] - for query in queries: - i = i + 1 - istr = str(i) if i > 10 else "0"+str(i) - queryId = "TEST_" + istr - sampleId = str(fileSchemaType) + "." + queryId - yield sampleId, query, queryId, fileSchemaType + for sampleId, sample in samples.items(): + sampleUID = str(fileSchemaType) + "." 
+ sampleId + yield sampleUID, sampleId, fileSchemaType, datasource_tables -def run_queries(bc, dask_client, nRals, drill, dir_data_lc, tables, **kwargs): +def run_queries(bc, dask_client, nRals, drill, spark, dir_data_lc, tables, **kwargs): sql_table_filter_map = kwargs.get("sql_table_filter_map", {}) sql_table_batch_size_map = kwargs.get("sql_table_batch_size_map", {}) sql = kwargs.get("sql_connection", None) @@ -108,61 +146,92 @@ def run_queries(bc, dask_client, nRals, drill, dir_data_lc, tables, **kwargs): "sql_table_filter_map": sql_table_filter_map, "sql_table_batch_size_map": sql_table_batch_size_map, "sql_connection": sql, + "dir_data_lc": dir_data_lc, } currrentFileSchemaType = data_types[0] - for sampleId, query, queryId, fileSchemaType in samples(bc, dask_client, nRals, **extra_args): + for sampleUID, sampleId, fileSchemaType, datasource_tables in sample_items(bc, dask_client, nRals, **extra_args): datasourceDone = (fileSchemaType != currrentFileSchemaType) if datasourceDone and Settings.execution_mode == ExecutionMode.GENERATOR: print("==============================") break_flag = True break + sample = samples[sampleId] + + query = sample.table_mapper(sample.query, datasource_tables) # map to tables with datasource info: order_csv, nation_csv ... + worder = sample.worder + use_percentage = sample.use_percentage + acceptable_difference = sample.acceptable_difference + use_pyspark = sample.use_pyspark + engine = spark if use_pyspark else drill + query_spark = sample.table_mapper(sample.query) # map to tables without datasource info: order, nation ... + print("==>> Run query for sample", sampleId) + print("PLAN:") + print(bc.explain(query, True)) runTest.run_query( bc, - drill, + engine, query, - queryId, + sampleId, queryType, worder, "", acceptable_difference, use_percentage, fileSchemaType, + query_spark = query_spark, print_result = True ) currrentFileSchemaType = fileSchemaType -def setup_test() -> bool: - sql = createSchema.get_sql_connection(DataType.MYSQL) +def setup_test(data_type: DataType) -> createSchema.sql_connection: + sql = createSchema.get_sql_connection(data_type) if not sql: - print("ERROR: You cannot run tablesFromSQL test, settup your SQL connection using env vars! See tests/README.md") + print(f"ERROR: You cannot run tablesFromSQL test, setup your SQL connection for {data_type} using env vars!
See tests/README.md") return None - from DataBase import mysqlSchema + if data_type is DataType.MYSQL: + from DataBase import mysqlSchema + mysqlSchema.create_and_load_tpch_schema(sql) + return sql + + if data_type is DataType.SQLITE: + from DataBase import sqliteSchema + sqliteSchema.create_and_load_tpch_schema(sql) + return sql - mysqlSchema.create_and_load_tpch_schema(sql) - return sql + if data_type is DataType.POSTGRESQL: + from DataBase import postgreSQLSchema + postgreSQLSchema.create_and_load_tpch_schema(sql) + return sql -def executionTest(dask_client, drill, dir_data_lc, bc, nRals, sql): +def executionTest(dask_client, drill, spark, dir_data_lc, bc, nRals, sql): extra_args = { "sql_table_filter_map": sql_table_filters, "sql_table_batch_size_map": sql_table_batch_sizes, "sql_connection": sql, } - run_queries(bc, dask_client, nRals, drill, dir_data_lc, tables, **extra_args) + run_queries(bc, dask_client, nRals, drill, spark, dir_data_lc, tables, **extra_args) -def main(dask_client, drill, dir_data_lc, bc, nRals): +def main(dask_client, drill, spark, dir_data_lc, bc, nRals): print("==============================") print(queryType) print("==============================") - sql = setup_test() - if sql: - start_mem = gpuMemory.capture_gpu_memory_usage() - executionTest(dask_client, drill, dir_data_lc, bc, nRals, sql) - end_mem = gpuMemory.capture_gpu_memory_usage() - gpuMemory.log_memory_usage(queryType, start_mem, end_mem) + for data_type in data_types: + sql = None + is_file_ds = False + # we can change the datatype for these tests and it should works just fine + if data_type not in [DataType.MYSQL, DataType.POSTGRESQL, DataType.SQLITE]: + is_file_ds = True + else: + sql = setup_test(data_type) + if sql or is_file_ds: + start_mem = gpuMemory.capture_gpu_memory_usage() + executionTest(dask_client, drill, spark, dir_data_lc, bc, nRals, sql) + end_mem = gpuMemory.capture_gpu_memory_usage() + gpuMemory.log_memory_usage(queryType, start_mem, end_mem) diff --git a/tests/BlazingSQLTest/EndToEndTests/tpchQueries.py b/tests/BlazingSQLTest/EndToEndTests/tpchQueries.py index 159230564..00d3ac720 100644 --- a/tests/BlazingSQLTest/EndToEndTests/tpchQueries.py +++ b/tests/BlazingSQLTest/EndToEndTests/tpchQueries.py @@ -1,831 +1,838 @@ -def get_tpch_query(test_id, tables = {}): - nation = tables.get("nation", "nation") - region = tables.get("region", "region") - customer = tables.get("customer", "customer") - lineitem = tables.get("lineitem", "lineitem") - orders = tables.get("orders", "orders") - supplier = tables.get("supplier", "supplier") - part = tables.get("part", "part") - partsupp = tables.get("partsupp", "partsupp") - - queries = { - "TEST_01": f""" - select - l_returnflag, - l_linestatus, - sum(l_quantity) as sum_qty, - sum(l_extendedprice) as sum_base_price, - sum(l_extendedprice*(1-l_discount)) as sum_disc_price, - sum(l_extendedprice*(1-l_discount)*(1+l_tax)) - as sum_charge, - avg(l_quantity) as avg_qty, - avg(l_extendedprice) as avg_price, - avg(l_discount) as avg_disc, - count(*) as count_order - from - {lineitem} - where - l_shipdate <= date '1998-12-01' - interval '90' day - group by - l_returnflag, - l_linestatus - order by - l_returnflag, - l_linestatus - """, - - # Edited: - # - implicit joins generated some condition=[true] on Blazingsql - # - added table aliases to avoid ambiguity on Drill - - "TEST_02": f""" +query_templates = { + "TEST_01": """ + select + l_returnflag, + l_linestatus, + sum(l_quantity) as sum_qty, + sum(l_extendedprice) as sum_base_price, + 
sum(l_extendedprice*(1-l_discount)) as sum_disc_price, + sum(l_extendedprice*(1-l_discount)*(1+l_tax)) + as sum_charge, + avg(l_quantity) as avg_qty, + avg(l_extendedprice) as avg_price, + avg(l_discount) as avg_disc, + count(*) as count_order + from + {lineitem} + where + l_shipdate <= date '1998-12-01' - interval '90' day + group by + l_returnflag, + l_linestatus + order by + l_returnflag, + l_linestatus + """, + + # Edited: + # - implicit joins generated some condition=[true] on Blazingsql + # - added table aliases to avoid ambiguity on Drill + + "TEST_02": """ + select + s.s_acctbal, + s.s_name, + n.n_name, + p.p_partkey, + p.p_mfgr, + s.s_address, + s.s_phone, + s.s_comment + from + {supplier} as s + inner join {nation} as n on s.s_nationkey = n.n_nationkey + inner join {partsupp} as ps on s.s_suppkey = ps.ps_suppkey + inner join {part} as p on p.p_partkey = ps.ps_partkey + inner join {region} as r on r.r_regionkey = n.n_regionkey + where + p.p_size = 15 + and p.p_type like '%BRASS' + and r.r_name = 'EUROPE' + and ps.ps_supplycost = ( + select + min(psq.ps_supplycost) + from + {partsupp} as psq + inner join {supplier} sq on + sq.s_suppkey = psq.ps_suppkey + inner join {nation} as nq on + sq.s_nationkey = nq.n_nationkey + inner join {region} as rq on + nq.n_regionkey = rq.r_regionkey + where + p.p_partkey = psq.ps_partkey + and rq.r_name = 'EUROPE' + ) + order by + s.s_acctbal desc, + n.n_name, + s.s_name, + p.p_partkey + limit 100 + """, + + # Edited: + # - implicit joins without table aliases causes + # parsing errors on Drill + # - added table aliases to avoid ambiguity on Drill + # - There is an issue with validation on gpuci + + "TEST_03": """ + select + l.l_orderkey, + sum(l.l_extendedprice*(1-l.l_discount)) as revenue, + o.o_orderdate, + o.o_shippriority + from + {customer} c + inner join {orders} o + on c.c_custkey = o.o_custkey + inner join {lineitem} l + on l.l_orderkey = o.o_orderkey + where + c.c_mktsegment = 'BUILDING' + and o.o_orderdate < date '1995-03-15' + and l.l_shipdate > date '1995-03-15' + group by + l.l_orderkey, + o.o_orderdate, + o.o_shippriority + order by + l.l_orderkey, + revenue desc, + o.o_orderdate + limit 10 + """, + + # WARNING: + # - Became implicit joins into explicit joins + # - Fails with Drill, passes only with ORC files on PySpark + # - Passes with BigQuery + # - Blazingsql is returning different results + # for parquet, psv, and gdf + + "TEST_04": """ + select + o.o_orderpriority, + count(*) as order_count + from + {orders} o + where + o.o_orderdate >= date '1993-07-01' + and o.o_orderdate < date '1993-07-01' + interval '3' month + and exists ( + select + * + from + {lineitem} l + where + l.l_orderkey = o.o_orderkey + and l.l_commitdate < l.l_receiptdate + ) + group by + o.o_orderpriority + order by + o.o_orderpriority + """, + + # Edited: + # - implicit joins without table aliases causes + # parsing errors on Drill + + "TEST_05": """ + select + n.n_name, + sum(l.l_extendedprice * (1 - l.l_discount)) as revenue + from + {customer} as c + inner join {orders} as o + on c.c_custkey = o.o_custkey + inner join {lineitem} as l + on l.l_orderkey = o.o_orderkey + inner join {supplier} as s + on l.l_suppkey = s.s_suppkey + inner join {nation} as n + on s.s_nationkey = n.n_nationkey + inner join {region} as r + on n.n_regionkey = r.r_regionkey + inner join {customer} c2 + on c2.c_nationkey = s.s_nationkey + where + r.r_name = 'ASIA' + and o.o_orderdate >= date '1994-01-01' + and o.o_orderdate < date '1995-01-01' + group by + n.n_name, + o.o_orderkey, + 
l.l_linenumber + order by + revenue desc + """, + + # Edited: + # - Became implicit joins into explicit joins + # - Added o.o_orderkey, l.l_linenumber into group by clause + # - Changed ('1994-01-01' + interval '1' year) by date '1995-01-01' + # to became the query deterministic. + # - Even that there is a difference with evaluations on Calcite, + # the query passes on PySpark and BigQuery but fails on Drill + # >=($6, -(0.06:DECIMAL(3, 2), 0.01:DECIMAL(3, 2))), + # <=($6, +(0.06:DECIMAL(3, 2), 0.01:DECIMAL(3, 2))) became + # >=($2, 0.05:DECIMAL(4, 2)), <=($2, 0.07:DECIMAL(4, 2)) + + "TEST_06": """ + select + sum(l_extendedprice*l_discount) as revenue + from + {lineitem} + where + l_shipdate >= date '1994-01-01' + and l_shipdate < date '1994-01-01' + interval '1' year + and l_discount between 0.06 - 0.01 and 0.06 + 0.01 + and l_quantity < 24 + """, + + # Edited: + # - implicit joins without table aliases causes + # parsing errors on Drill + + "TEST_07": """ + select + supp_nation, + cust_nation, + l_year, sum(volume) as revenue + from ( select - s.s_acctbal, - s.s_name, - n.n_name, - p.p_partkey, - p.p_mfgr, - s.s_address, - s.s_phone, - s.s_comment + n1.n_name as supp_nation, + n2.n_name as cust_nation, + extract(year from l.l_shipdate) as l_year, + l.l_extendedprice * (1 - l.l_discount) as volume from {supplier} as s - inner join {nation} as n on s.s_nationkey = n.n_nationkey - inner join {partsupp} as ps on s.s_suppkey = ps.ps_suppkey - inner join {part} as p on p.p_partkey = ps.ps_partkey - inner join {region} as r on r.r_regionkey = n.n_regionkey + inner join {lineitem} as l + on s.s_suppkey = l.l_suppkey + inner join {orders} as o + on o.o_orderkey = l.l_orderkey + inner join {customer} as c + on c.c_custkey = o.o_custkey + inner join {nation} as n1 + on s.s_nationkey = n1.n_nationkey + inner join {nation} as n2 + on c.c_nationkey = n2.n_nationkey where - p.p_size = 15 - and p.p_type like '%BRASS' - and r.r_name = 'EUROPE' - and ps.ps_supplycost = ( - select - min(psq.ps_supplycost) - from - {partsupp} as psq - inner join {supplier} sq on - sq.s_suppkey = psq.ps_suppkey - inner join {nation} as nq on - sq.s_nationkey = nq.n_nationkey - inner join {region} as rq on - nq.n_regionkey = rq.r_regionkey - where - p.p_partkey = psq.ps_partkey - and rq.r_name = 'EUROPE' - ) - order by - s.s_acctbal desc, - n.n_name, - s.s_name, - p.p_partkey - limit 100 - """, - - # Edited: - # - implicit joins without table aliases causes - # parsing errors on Drill - # - added table aliases to avoid ambiguity on Drill - # - There is an issue with validation on gpuci - - "TEST_03": f""" + ( + (n1.n_name = 'FRANCE' and n2.n_name = 'GERMANY') + or (n1.n_name = 'GERMANY' and n2.n_name = 'FRANCE') + ) + and l.l_shipdate between date '1995-01-01' and + date '1996-12-31' + ) as shipping + group by + supp_nation, + cust_nation, + l_year + order by + supp_nation, + cust_nation, + l_year + """, + + # Edited: + # - implicit joins generated some condition=[true] on Blazingsql + # - 'nation' colum name was renamed to nationl because it produces + # a parse error on Drill + # - added table aliases to avoid ambiguity on Drill + + "TEST_08": """ + select + o_year, + sum(case + when nationl = 'BRAZIL' + then volume + else 0 + end) / sum(volume) as mkt_share + from ( select - l.l_orderkey, - sum(l.l_extendedprice*(1-l.l_discount)) as revenue, - o.o_orderdate, - o.o_shippriority + extract(year from o.o_orderdate) as o_year, + l.l_extendedprice * (1-l.l_discount) as volume, + n2.n_name as nationl from - {customer} c - inner 
join {orders} o - on c.c_custkey = o.o_custkey - inner join {lineitem} l - on l.l_orderkey = o.o_orderkey + {part} as p + inner join {lineitem} as l on p.p_partkey = l.l_partkey + inner join {supplier} as s on s.s_suppkey = l.l_suppkey + inner join {orders} as o on o.o_orderkey = l.l_orderkey + inner join {customer} as c on c.c_custkey = o.o_custkey + inner join {nation} as n1 on + n1.n_nationkey = c.c_nationkey + inner join {nation} as n2 on + n2.n_nationkey = s.s_nationkey + inner join {region} as r on + r.r_regionkey = n1.n_regionkey where - c.c_mktsegment = 'BUILDING' - and o.o_orderdate < date '1995-03-15' - and l.l_shipdate > date '1995-03-15' - group by - l.l_orderkey, - o.o_orderdate, - o.o_shippriority + r.r_name = 'AMERICA' + and o.o_orderdate between date '1995-01-01' and + date '1996-12-31' + and p.p_type = 'ECONOMY ANODIZED STEEL' + ) as all_nations + group by + o_year order by - l.l_orderkey, - revenue desc, - o.o_orderdate - limit 10 - """, - - # WARNING: - # - Became implicit joins into explicit joins - # - Fails with Drill, passes only with ORC files on PySpark - # - Passes with BigQuery - # - Blazingsql is returning different results - # for parquet, psv, and gdf - - "TEST_04": f""" + o_year + """, + + # Edited: + # - implicit joins generated some condition=[true] on Blazingsql + # - 'nation' colum name was renamed to nationl because it + # produces a parse error on Drill + # - implicit joins without table aliases causes parsing + # errors on Drill + + "TEST_09": """ + select + nationl, + o_year, + sum(amount) as sum_profit + from ( select - o.o_orderpriority, - count(*) as order_count + n.n_name as nationl, + extract(year from o.o_orderdate) as o_year, + l.l_extendedprice * (1 - l.l_discount) - + ps.ps_supplycost * l.l_quantity as amount from - {orders} o + {lineitem} as l + inner join {orders} as o + on o.o_orderkey = l.l_orderkey + inner join {partsupp} as ps + on ps.ps_suppkey = l.l_suppkey + inner join {part} as p + on p.p_partkey = l.l_partkey + inner join {supplier} as s + on s.s_suppkey = l.l_suppkey + inner join {nation} as n + on n.n_nationkey = s.s_nationkey where - o.o_orderdate >= date '1993-07-01' - and o.o_orderdate < date '1993-07-01' + interval '3' month - and exists ( + l.l_partkey = ps.ps_partkey + and p.p_name like '%green%' + ) as profit + group by + nationl, + o_year + order by + nationl, + o_year desc + """, + + # Edited: + # - implicit joins without table aliases causes parsing + # errors on Drill + # - no needed to converting to explicit joins, added + # only table aliases + # - order by c.c_custkey, with null data it is necessary + # for it to match with drill or spark, because there is a "Limit" + + "TEST_10": """ + select + c.c_custkey, + c.c_name, + sum(l.l_extendedprice * (1 - l.l_discount)) as revenue, + c.c_acctbal, + n.n_name, + c.c_address, + c.c_phone, + c.c_comment + from + {customer} c + inner join {orders} o + on c.c_custkey = o.o_custkey + inner join {lineitem} l + on l.l_orderkey = o.o_orderkey + inner join {nation} n + on c.c_nationkey = n.n_nationkey + where + o.o_orderdate >= date '1993-10-01' + and o.o_orderdate < date '1993-10-01' + interval '3' month + and l.l_returnflag = 'R' + group by + c.c_custkey, + c.c_name, + c.c_acctbal, + c.c_phone, + n.n_name, + c.c_address, + c.c_comment + order by + revenue desc, + c.c_custkey + limit 20 + """, + + # Edited: + # - 'value' colum name was renamed to valuep because it produces + # a parse error on Drill + # WARNING: Join condition is currently not supported + + "TEST_11": """ + select 
+ ps_partkey, + sum(ps_supplycost * ps_availqty) as valuep + from + {partsupp}, + {supplier}, + {nation} + where + ps_suppkey = s_suppkey + and s_nationkey = n_nationkey + and n_name = 'GERMANY' + group by + ps_partkey having + sum(ps_supplycost * ps_availqty) > ( select - * + sum(ps_supplycost * ps_availqty) * 0.0001 from - {lineitem} l + {partsupp}, + {supplier}, + {nation} where - l.l_orderkey = o.o_orderkey - and l.l_commitdate < l.l_receiptdate - ) - group by - o.o_orderpriority - order by - o.o_orderpriority - """, - - # Edited: - # - implicit joins without table aliases causes - # parsing errors on Drill - - "TEST_05": f""" + ps_suppkey = s_suppkey + and s_nationkey = n_nationkey + and n_name = 'GERMANY' + ) + order by + valuep desc + """, + + # Edited: + # - implicit joins without table aliases causes parsing + # errors on Drill + # - no needed to converting to explicit joins, added + # only table aliases + + "TEST_12": """ + select + l.l_shipmode, + sum(case + when o.o_orderpriority ='1-URGENT' + or o.o_orderpriority ='2-HIGH' + then 1 + else 0 + end) as high_line_count, + sum(case + when o.o_orderpriority <> '1-URGENT' + and o.o_orderpriority <> '2-HIGH' + then 1 + else 0 + end) as low_line_count + from + {orders} o + inner join {lineitem} l + on o.o_orderkey = l.l_orderkey + where + l.l_shipmode in ('MAIL', 'SHIP') + and l.l_commitdate < l.l_receiptdate + and l.l_shipdate < l.l_commitdate + and l.l_receiptdate >= date '1994-01-01' + and l.l_receiptdate < date '1994-01-01' + + interval '1' year + group by + l.l_shipmode + order by + l.l_shipmode + """, + + # Edited: + # - added table aliases to avoid ambiguity on Drill + + "TEST_13": """ + select + c_count, count(*) as custdist + from ( select - n.n_name, - sum(l.l_extendedprice * (1 - l.l_discount)) as revenue + c.c_custkey, + count(o.o_orderkey) from - {customer} as c - inner join {orders} as o - on c.c_custkey = o.o_custkey - inner join {lineitem} as l - on l.l_orderkey = o.o_orderkey - inner join {supplier} as s - on l.l_suppkey = s.s_suppkey - inner join {nation} as n - on s.s_nationkey = n.n_nationkey - inner join {region} as r - on n.n_regionkey = r.r_regionkey - inner join {customer} c2 - on c2.c_nationkey = s.s_nationkey - where - r.r_name = 'ASIA' - and o.o_orderdate >= date '1994-01-01' - and o.o_orderdate < date '1995-01-01' + {customer} c left outer join {orders} o on + c.c_custkey = o.o_custkey + and o.o_comment not like '%special%requests%' group by - n.n_name, - o.o_orderkey, - l.l_linenumber - order by - revenue desc - """, - - # Edited: - # - Became implicit joins into explicit joins - # - Added o.o_orderkey, l.l_linenumber into group by clause - # - Changed ('1994-01-01' + interval '1' year) by date '1995-01-01' - # to became the query deterministic. 
- # - Even that there is a difference with evaluations on Calcite, - # the query passes on PySpark and BigQuery but fails on Drill - # >=($6, -(0.06:DECIMAL(3, 2), 0.01:DECIMAL(3, 2))), - # <=($6, +(0.06:DECIMAL(3, 2), 0.01:DECIMAL(3, 2))) became - # >=($2, 0.05:DECIMAL(4, 2)), <=($2, 0.07:DECIMAL(4, 2)) - - "TEST_06": f""" + c.c_custkey + )as c_orders (c_custkey, c_count) + group by + c_count + order by + custdist desc, + c_count desc + """, + + # Edited: + # - implicit joins without table aliases causes parsing + # errors on Drill + # - no needed to converting to explicit joins, added + # only table aliases + + "TEST_14": """ + select + 100.00 * sum(case + when p.p_type like 'PROMO%' + then l.l_extendedprice*(1-l.l_discount) + else 0 + end) / sum(l.l_extendedprice * (1 - l.l_discount)) + as promo_revenue + from + {lineitem} l + inner join {part} p + on l.l_partkey = p.p_partkey + where + l.l_shipdate >= date '1995-09-01' + and l.l_shipdate < date '1995-09-01' + interval '1' month + """, + + "TEST_15": """ + with revenue (suplier_no, total_revenue) as ( select - sum(l_extendedprice*l_discount) as revenue + l_suppkey, + cast(sum(l_extendedprice * (1-l_discount)) AS INTEGER) from {lineitem} where - l_shipdate >= date '1994-01-01' - and l_shipdate < date '1994-01-01' + interval '1' year - and l_discount between 0.06 - 0.01 and 0.06 + 0.01 - and l_quantity < 24 - """, - - # Edited: - # - implicit joins without table aliases causes - # parsing errors on Drill - - "TEST_07": f""" - select - supp_nation, - cust_nation, - l_year, sum(volume) as revenue - from ( - select - n1.n_name as supp_nation, - n2.n_name as cust_nation, - extract(year from l.l_shipdate) as l_year, - l.l_extendedprice * (1 - l.l_discount) as volume - from - {supplier} as s - inner join {lineitem} as l - on s.s_suppkey = l.l_suppkey - inner join {orders} as o - on o.o_orderkey = l.l_orderkey - inner join {customer} as c - on c.c_custkey = o.o_custkey - inner join {nation} as n1 - on s.s_nationkey = n1.n_nationkey - inner join {nation} as n2 - on c.c_nationkey = n2.n_nationkey - where - ( - (n1.n_name = 'FRANCE' and n2.n_name = 'GERMANY') - or (n1.n_name = 'GERMANY' and n2.n_name = 'FRANCE') - ) - and l.l_shipdate between date '1995-01-01' and - date '1996-12-31' - ) as shipping + l_shipdate >= date '1996-01-01' + and l_shipdate < date '1996-01-01' + interval '3' month group by - supp_nation, - cust_nation, - l_year - order by - supp_nation, - cust_nation, - l_year - """, - - # Edited: - # - implicit joins generated some condition=[true] on Blazingsql - # - 'nation' colum name was renamed to nationl because it produces - # a parse error on Drill - # - added table aliases to avoid ambiguity on Drill - - "TEST_08": f""" - select - o_year, - sum(case - when nationl = 'BRAZIL' - then volume - else 0 - end) / sum(volume) as mkt_share - from ( + l_suppkey + ) + select + s_suppkey, + s_name, + s_address, + s_phone, + total_revenue + from + {supplier} + inner join revenue + on s_suppkey = suplier_no + where + total_revenue = ( select - extract(year from o.o_orderdate) as o_year, - l.l_extendedprice * (1-l.l_discount) as volume, - n2.n_name as nationl + max(total_revenue) from - {part} as p - inner join {lineitem} as l on p.p_partkey = l.l_partkey - inner join {supplier} as s on s.s_suppkey = l.l_suppkey - inner join {orders} as o on o.o_orderkey = l.l_orderkey - inner join {customer} as c on c.c_custkey = o.o_custkey - inner join {nation} as n1 on - n1.n_nationkey = c.c_nationkey - inner join {nation} as n2 on - n2.n_nationkey = 
s.s_nationkey - inner join {region} as r on - r.r_regionkey = n1.n_regionkey - where - r.r_name = 'AMERICA' - and o.o_orderdate between date '1995-01-01' and - date '1996-12-31' - and p.p_type = 'ECONOMY ANODIZED STEEL' - ) as all_nations - group by - o_year - order by - o_year - """, - - # Edited: - # - implicit joins generated some condition=[true] on Blazingsql - # - 'nation' colum name was renamed to nationl because it - # produces a parse error on Drill - # - implicit joins without table aliases causes parsing - # errors on Drill - - "TEST_09": f""" - select - nationl, - o_year, - sum(amount) as sum_profit - from ( + revenue + ) + order by + s_suppkey + """, + + # Edited: + # - Replacing 'create view' by 'with' clause + # - Pyspark doest not support this syntax as is + # WARNING: Drill presents undeterministic results + + "TEST_16": """ + select + p.p_brand, + p.p_type, + p.p_size, + count(distinct ps.ps_suppkey) as supplier_cnt + from + {partsupp} ps + inner join {part} p on p.p_partkey = ps.ps_partkey + where + p.p_brand <> 'Brand#45' + and p.p_type not like 'MEDIUM POLISHED%' + and p.p_size in (49, 14, 23, 45, 19, 3, 36, 9) + and ps.ps_suppkey not in ( select - n.n_name as nationl, - extract(year from o.o_orderdate) as o_year, - l.l_extendedprice * (1 - l.l_discount) - - ps.ps_supplycost * l.l_quantity as amount + s_suppkey from - {lineitem} as l - inner join {orders} as o - on o.o_orderkey = l.l_orderkey - inner join {partsupp} as ps - on ps.ps_suppkey = l.l_suppkey - inner join {part} as p - on p.p_partkey = l.l_partkey - inner join {supplier} as s - on s.s_suppkey = l.l_suppkey - inner join {nation} as n - on n.n_nationkey = s.s_nationkey + {supplier} where - l.l_partkey = ps.ps_partkey - and p.p_name like '%green%' - ) as profit - group by - nationl, - o_year - order by - nationl, - o_year desc - """, - - # Edited: - # - implicit joins without table aliases causes parsing - # errors on Drill - # - no needed to converting to explicit joins, added - # only table aliases - # - order by c.c_custkey, with null data it is necessary - # for it to match with drill or spark, because there is a "Limit" - - "TEST_10": f""" - select - c.c_custkey, - c.c_name, - sum(l.l_extendedprice * (1 - l.l_discount)) as revenue, - c.c_acctbal, - n.n_name, - c.c_address, - c.c_phone, - c.c_comment - from - {customer} c - inner join {orders} o - on c.c_custkey = o.o_custkey - inner join {lineitem} l - on l.l_orderkey = o.o_orderkey - inner join {nation} n - on c.c_nationkey = n.n_nationkey - where - o.o_orderdate >= date '1993-10-01' - and o.o_orderdate < date '1993-10-01' + interval '3' month - and l.l_returnflag = 'R' - group by - c.c_custkey, - c.c_name, - c.c_acctbal, - c.c_phone, - n.n_name, - c.c_address, - c.c_comment - order by - revenue desc, - c.c_custkey - limit 20 - """, - - # Edited: - # - 'value' colum name was renamed to valuep because it produces - # a parse error on Drill - # WARNING: Join condition is currently not supported - - "TEST_11": f""" - select - ps_partkey, - sum(ps_supplycost * ps_availqty) as valuep - from - {partsupp}, - {supplier}, - {nation} - where - ps_suppkey = s_suppkey - and s_nationkey = n_nationkey - and n_name = 'GERMANY' - group by - ps_partkey having - sum(ps_supplycost * ps_availqty) > ( - select - sum(ps_supplycost * ps_availqty) * 0.0001 - from - {partsupp}, - {supplier}, - {nation} - where - ps_suppkey = s_suppkey - and s_nationkey = n_nationkey - and n_name = 'GERMANY' - ) - order by - valuep desc - """, - - # Edited: - # - implicit joins without table 
aliases causes parsing - # errors on Drill - # - no needed to converting to explicit joins, added - # only table aliases - - "TEST_12": f""" - select - l.l_shipmode, - sum(case - when o.o_orderpriority ='1-URGENT' - or o.o_orderpriority ='2-HIGH' - then 1 - else 0 - end) as high_line_count, - sum(case - when o.o_orderpriority <> '1-URGENT' - and o.o_orderpriority <> '2-HIGH' - then 1 - else 0 - end) as low_line_count - from - {orders} o - inner join {lineitem} l - on o.o_orderkey = l.l_orderkey - where - l.l_shipmode in ('MAIL', 'SHIP') - and l.l_commitdate < l.l_receiptdate - and l.l_shipdate < l.l_commitdate - and l.l_receiptdate >= date '1994-01-01' - and l.l_receiptdate < date '1994-01-01' + - interval '1' year - group by - l.l_shipmode - order by - l.l_shipmode - """, - - # Edited: - # - added table aliases to avoid ambiguity on Drill - - "TEST_13": f""" - select - c_count, count(*) as custdist - from ( + s_comment like '%Customer%Complaints%' + ) + group by + p.p_brand, + p.p_type, + p.p_size + order by + supplier_cnt desc, + p.p_brand, + p.p_type, + p.p_size + """, + + # Edited: + # - Became implicit joins into explicit joins + # - implicit joins generated some condition=[true] on Blazingsql + # - added table aliases to avoid ambiguity on Drill + + "TEST_17": """ + select + sum(l.l_extendedprice) / 7.0 as avg_yearly + from + {lineitem} l + inner join {part} p + on p.p_partkey = l.l_partkey + where + p.p_brand = 'Brand#23' + and p.p_container = 'MED BOX' + and l.l_quantity < ( select - c.c_custkey, - count(o.o_orderkey) + 0.2 * avg(l_quantity) from - {customer} c left outer join {orders} o on - c.c_custkey = o.o_custkey - and o.o_comment not like '%special%requests%' - group by - c.c_custkey - )as c_orders (c_custkey, c_count) - group by - c_count - order by - custdist desc, - c_count desc - """, - - # Edited: - # - implicit joins without table aliases causes parsing - # errors on Drill - # - no needed to converting to explicit joins, added - # only table aliases - - "TEST_14": f""" - select - 100.00 * sum(case - when p.p_type like 'PROMO%' - then l.l_extendedprice*(1-l.l_discount) - else 0 - end) / sum(l.l_extendedprice * (1 - l.l_discount)) - as promo_revenue - from - {lineitem} l - inner join {part} p - on l.l_partkey = p.p_partkey - where - l.l_shipdate >= date '1995-09-01' - and l.l_shipdate < date '1995-09-01' + interval '1' month - """, - - "TEST_15": f""" - with revenue (suplier_no, total_revenue) as ( + {lineitem} + where + l_partkey = p_partkey) + """, + + # Edited: + # - became implicit joins into explicit joins + # - no needed to converting to explicit joins, added + # only table aliases + # - this query fails on Drill with all format files, + # but passes on PySpark + + "TEST_18": """ + select + c.c_name, + c.c_custkey, + o.o_orderkey, + o.o_orderdate, + o.o_totalprice, + sum(l.l_quantity) + from + {customer} c + inner join {orders} o + on c.c_custkey = o.o_custkey + inner join {lineitem} l + on o.o_orderkey = l.l_orderkey + where + o.o_orderkey in ( select - l_suppkey, - cast(sum(l_extendedprice * (1-l_discount)) AS INTEGER) + l_orderkey from {lineitem} - where - l_shipdate >= date '1996-01-01' - and l_shipdate < date '1996-01-01' + interval '3' month group by - l_suppkey + l_orderkey having + sum(l_quantity) > 300 ) - select - s_suppkey, - s_name, - s_address, - s_phone, - total_revenue - from - {supplier} - inner join revenue - on s_suppkey = suplier_no - where - total_revenue = ( - select - max(total_revenue) - from - revenue - ) - order by - s_suppkey - """, - - # 
Edited: - # - Replacing 'create view' by 'with' clause - # - Pyspark doest not support this syntax as is - # WARNING: Drill presents undeterministic results - - "TEST_16": f""" - select - p.p_brand, - p.p_type, - p.p_size, - count(distinct ps.ps_suppkey) as supplier_cnt - from - {partsupp} ps - inner join {part} p on p.p_partkey = ps.ps_partkey - where - p.p_brand <> 'Brand#45' - and p.p_type not like 'MEDIUM POLISHED%' - and p.p_size in (49, 14, 23, 45, 19, 3, 36, 9) - and ps.ps_suppkey not in ( - select - s_suppkey - from - {supplier} - where - s_comment like '%Customer%Complaints%' - ) - group by - p.p_brand, - p.p_type, - p.p_size - order by - supplier_cnt desc, - p.p_brand, - p.p_type, - p.p_size - """, - - # Edited: - # - Became implicit joins into explicit joins - # - implicit joins generated some condition=[true] on Blazingsql - # - added table aliases to avoid ambiguity on Drill - - "TEST_17": f""" - select - sum(l.l_extendedprice) / 7.0 as avg_yearly - from - {lineitem} l - inner join {part} p - on p.p_partkey = l.l_partkey - where + group by + c.c_name, + c.c_custkey, + o.o_orderkey, + o.o_orderdate, + o.o_totalprice + order by + o.o_totalprice desc, + o.o_orderdate, + o.o_orderkey, + c.c_custkey + limit 100 + """, + + # Edited: + # - became implicit joins into explicit joins + # - implicit joins without table aliases causes parsing + # errors on Drill + # - no needed to converting to explicit joins, added only + # table aliases + + "TEST_19": """ + select + sum(l.l_extendedprice * (1 - l.l_discount) ) as revenue + from + {lineitem} l + inner join {part} p + ON l.l_partkey = p.p_partkey + where + ( + p.p_brand = 'Brand#12' + and p.p_container in ('SM CASE', 'SM BOX', + 'SM PACK', 'SM PKG') + and l.l_quantity >= 1 and l.l_quantity <= 1 + 10 + and p.p_size between 1 and 5 + and l.l_shipmode in ('AIR', 'AIR REG') + and l.l_shipinstruct = 'DELIVER IN PERSON' + ) + or + ( p.p_brand = 'Brand#23' - and p.p_container = 'MED BOX' - and l.l_quantity < ( - select - 0.2 * avg(l_quantity) - from - {lineitem} - where - l_partkey = p_partkey) - """, - - # Edited: - # - became implicit joins into explicit joins - # - no needed to converting to explicit joins, added - # only table aliases - # - this query fails on Drill with all format files, - # but passes on PySpark - - "TEST_18": f""" - select - c.c_name, - c.c_custkey, - o.o_orderkey, - o.o_orderdate, - o.o_totalprice, - sum(l.l_quantity) - from - {customer} c - inner join {orders} o - on c.c_custkey = o.o_custkey - inner join {lineitem} l - on o.o_orderkey = l.l_orderkey - where - o.o_orderkey in ( - select - l_orderkey - from - {lineitem} - group by - l_orderkey having - sum(l_quantity) > 300 - ) - group by - c.c_name, - c.c_custkey, - o.o_orderkey, - o.o_orderdate, - o.o_totalprice - order by - o.o_totalprice desc, - o.o_orderdate, - o.o_orderkey, - c.c_custkey - limit 100 - """, - - # Edited: - # - became implicit joins into explicit joins - # - implicit joins without table aliases causes parsing - # errors on Drill - # - no needed to converting to explicit joins, added only - # table aliases - - "TEST_19": f""" - select - sum(l.l_extendedprice * (1 - l.l_discount) ) as revenue - from - {lineitem} l - inner join {part} p - ON l.l_partkey = p.p_partkey - where - ( - p.p_brand = 'Brand#12' - and p.p_container in ('SM CASE', 'SM BOX', - 'SM PACK', 'SM PKG') - and l.l_quantity >= 1 and l.l_quantity <= 1 + 10 - and p.p_size between 1 and 5 - and l.l_shipmode in ('AIR', 'AIR REG') - and l.l_shipinstruct = 'DELIVER IN PERSON' - ) - or - ( - 
p.p_brand = 'Brand#23' - and p.p_container in ('MED BAG', 'MED BOX', - 'MED PKG', 'MED PACK') - and l.l_quantity >= 10 and l.l_quantity <= 10 + 10 - and p.p_size between 1 and 10 - and l.l_shipmode in ('AIR', 'AIR REG') - and l.l_shipinstruct = 'DELIVER IN PERSON' - ) - or - ( - p.p_brand = 'Brand#34' - and p.p_container in ('LG CASE', 'LG BOX', - 'LG PACK', 'LG PKG') - and l.l_quantity >= 20 and l.l_quantity <= 20 + 10 - and p.p_size between 1 and 15 - and l.l_shipmode in ('AIR', 'AIR REG') - and l.l_shipinstruct = 'DELIVER IN PERSON' - ) - """, - - # Edited: - # - implicit joins on Blazingsql generates 'Join condition is - # currently not supported' with - # LogicalJoin(condition=[OR( - # AND(=($10, $0), $11, $12, $2, $3, $13, $14, $4, $5), - # AND(=($10, $0), $15, $16, $6, $7, $13, $17, $4, $5), - # AND(=($10, $0), $18, $19, $8, $9, $13, $20, $4, $5))], - # joinType=[inner]) - # also parsing errors on Drill - # - added table aliases to avoid ambiguity on Drill - - "TEST_20": f""" - select - s.s_name, - s.s_address - from - {supplier} s - inner join {nation} n - on s.s_nationkey = n.n_nationkey - where - s.s_suppkey in ( - select - ps_suppkey - from - {partsupp} - where - ps_partkey in ( - select - p_partkey - from - {part} - where - p_name like 'forest%' - ) - and ps_availqty > ( + and p.p_container in ('MED BAG', 'MED BOX', + 'MED PKG', 'MED PACK') + and l.l_quantity >= 10 and l.l_quantity <= 10 + 10 + and p.p_size between 1 and 10 + and l.l_shipmode in ('AIR', 'AIR REG') + and l.l_shipinstruct = 'DELIVER IN PERSON' + ) + or + ( + p.p_brand = 'Brand#34' + and p.p_container in ('LG CASE', 'LG BOX', + 'LG PACK', 'LG PKG') + and l.l_quantity >= 20 and l.l_quantity <= 20 + 10 + and p.p_size between 1 and 15 + and l.l_shipmode in ('AIR', 'AIR REG') + and l.l_shipinstruct = 'DELIVER IN PERSON' + ) + """, + + # Edited: + # - implicit joins on Blazingsql generates 'Join condition is + # currently not supported' with + # LogicalJoin(condition=[OR( + # AND(=($10, $0), $11, $12, $2, $3, $13, $14, $4, $5), + # AND(=($10, $0), $15, $16, $6, $7, $13, $17, $4, $5), + # AND(=($10, $0), $18, $19, $8, $9, $13, $20, $4, $5))], + # joinType=[inner]) + # also parsing errors on Drill + # - added table aliases to avoid ambiguity on Drill + + "TEST_20": """ + select + s.s_name, + s.s_address + from + {supplier} s + inner join {nation} n + on s.s_nationkey = n.n_nationkey + where + s.s_suppkey in ( + select + ps_suppkey + from + {partsupp} + where + ps_partkey in ( select - 0.5 * sum(l_quantity) + p_partkey from - {lineitem} + {part} where - l_partkey = ps_partkey - and l_suppkey = ps_suppkey - and l_shipdate >= date '1994-01-01' - and l_shipdate < - date '1994-01-01' + interval '1' year + p_name like 'forest%' ) + and ps_availqty > ( + select + 0.5 * sum(l_quantity) + from + {lineitem} + where + l_partkey = ps_partkey + and l_suppkey = ps_suppkey + and l_shipdate >= date '1994-01-01' + and l_shipdate < + date '1994-01-01' + interval '1' year ) - and n.n_name = 'CANADA' - order by - s.s_name - """, - - # Edited: - # - implicit joins without table aliases causes parsing errors - # on Drill no needed to converting to explicit joins, added - # only table aliases - # - this query fails on Drill with all format files, but passes - # on PySpark - - "TEST_21": f""" + ) + and n.n_name = 'CANADA' + order by + s.s_name + """, + + # Edited: + # - implicit joins without table aliases causes parsing errors + # on Drill no needed to converting to explicit joins, added + # only table aliases + # - this query fails on Drill 
with all format files, but passes + # on PySpark + + "TEST_21": """ + select + s_name, + count(*) as numwait + from + {supplier} + inner join {lineitem} l1 on s_suppkey = l1.l_suppkey + inner join {orders} on o_orderkey = l1.l_orderkey + inner join {nation} on n_nationkey = s_nationkey + where + o_orderstatus = 'F' + and l1.l_receiptdate > l1.l_commitdate + and exists ( + select + * + from + {lineitem} l2 + where + l2.l_orderkey = l1.l_orderkey + and l2.l_suppkey <> l1.l_suppkey + ) + and not exists ( + select + * + from + {lineitem} l3 + where + l3.l_orderkey = l1.l_orderkey + and l3.l_suppkey <> l1.l_suppkey + and l3.l_receiptdate > l3.l_commitdate + ) + and n_name = 'SAUDI ARABIA' + group by + s_name + order by + numwait desc, + s_name + limit 100 + """, + + # Edited: + # - implicit joins generated some condition=[true] on Blazingsql + # - by comparing with Pyspark there is no needed + # of adding more table aliases + + "TEST_22": """ + select + cntrycode, + count(*) as numcust, + sum(c_acctbal) as totacctbal + from ( select - s_name, - count(*) as numwait + substring(c_phone from 1 for 2) as cntrycode, + c_acctbal from - {supplier} - inner join {lineitem} l1 on s_suppkey = l1.l_suppkey - inner join {orders} on o_orderkey = l1.l_orderkey - inner join {nation} on n_nationkey = s_nationkey + {customer} where - o_orderstatus = 'F' - and l1.l_receiptdate > l1.l_commitdate - and exists ( + substring(c_phone from 1 for 2) in + ('13','31','23','29','30','18','17') + and c_acctbal > ( select - * + avg(c_acctbal) from - {lineitem} l2 + {customer} where - l2.l_orderkey = l1.l_orderkey - and l2.l_suppkey <> l1.l_suppkey + c_acctbal > 0.00 + and substring (c_phone from 1 for 2) in + ('13','31','23','29','30','18','17') ) and not exists ( select * from - {lineitem} l3 + {orders} where - l3.l_orderkey = l1.l_orderkey - and l3.l_suppkey <> l1.l_suppkey - and l3.l_receiptdate > l3.l_commitdate + o_custkey = c_custkey ) - and n_name = 'SAUDI ARABIA' - group by - s_name - order by - numwait desc, - s_name - limit 100 - """, - - # Edited: - # - implicit joins generated some condition=[true] on Blazingsql - # - by comparing with Pyspark there is no needed - # of adding more table aliases + ) as custsale + group by + cntrycode + order by + cntrycode + """ + + # WARNING: Join condition is currently not supported +} + + +def map_tables(tpch_query, tables = {}): + tpch_tables = { + 'nation': tables.get("nation", "nation"), + 'region': tables.get("region", "region"), + 'customer': tables.get("customer", "customer"), + 'lineitem': tables.get("lineitem", "lineitem"), + 'orders': tables.get("orders", "orders"), + 'supplier': tables.get("supplier", "supplier"), + 'part': tables.get("part", "part"), + 'partsupp': tables.get("partsupp", "partsupp") + } - "TEST_22": f""" - select - cntrycode, - count(*) as numcust, - sum(c_acctbal) as totacctbal - from ( - select - substring(c_phone from 1 for 2) as cntrycode, - c_acctbal - from - {customer} - where - substring(c_phone from 1 for 2) in - ('13','31','23','29','30','18','17') - and c_acctbal > ( - select - avg(c_acctbal) - from - {customer} - where - c_acctbal > 0.00 - and substring (c_phone from 1 for 2) in - ('13','31','23','29','30','18','17') - ) - and not exists ( - select - * - from - {orders} - where - o_custkey = c_custkey - ) - ) as custsale - group by - cntrycode - order by - cntrycode - """ + return tpch_query.format(**tpch_tables) - # WARNING: Join condition is currently not supported - } - return queries.get(test_id) +def get_tpch_query(test_id, tables = 
{}): + return map_tables(query_templates.get(test_id), tables) diff --git a/tests/BlazingSQLTest/Utils/__init__.py b/tests/BlazingSQLTest/Utils/__init__.py index ea8416777..3e025f98d 100644 --- a/tests/BlazingSQLTest/Utils/__init__.py +++ b/tests/BlazingSQLTest/Utils/__init__.py @@ -29,6 +29,7 @@ def test_name(queryType, fileSchemaType): def skip_test(dask_client, nRals, fileSchemaType, queryType): testsWithNulls = Settings.data["RunSettings"]["testsWithNulls"] + executionMode = Settings.data['RunSettings']['executionMode'] if fileSchemaType == DataType.DASK_CUDF: # Skipping combination DASK_CUDF and testsWithNulls="true" @@ -47,6 +48,10 @@ def skip_test(dask_client, nRals, fileSchemaType, queryType): return skip + if executionMode == 'gpuci': + if fileSchemaType in [DataType.MYSQL, DataType.POSTGRESQL]: + return True + return False diff --git a/tests/README.md b/tests/README.md index bacc0b4d0..e40094b6e 100644 --- a/tests/README.md +++ b/tests/README.md @@ -133,8 +133,6 @@ BLAZINGSQL_E2E_TEST_WITH_NULLS=true BLAZINGSQL_E2E_EXEC_MODE="generator" ./tes ### Unit tests - - ```shell-script cd blazingsql @@ -252,12 +250,13 @@ We provide as well a copy of the Apache Hive software (tested with version 1.2.2 ``` #### MySQL, PostgreSQL, SQLite testing -For MySQL you will need to install this lib: +For MySQL and PostgreSQL you will need to install these libs: ```shell-script conda install -c conda-forge mysql-connector-python +conda install -c conda-forge psycopg2 ``` -and run the following line into mysql console: +and run the following line in the MySQL console: ```sql SET GLOBAL local_infile = 'ON'; ``` @@ -265,11 +264,22 @@ SET GLOBAL local_infile = 'ON'; To run the tests for tables from other SQL databases just define these env vars before run the test: ```shell-script -BLAZINGSQL_E2E_SQL_HOSTNAME -BLAZINGSQL_E2E_SQL_PORT -BLAZINGSQL_E2E_SQL_USERNAME -BLAZINGSQL_E2E_SQL_PASSWORD -BLAZINGSQL_E2E_SQL_SCHEMA +# for MySQL +BLAZINGSQL_E2E_MYSQL_HOSTNAME +BLAZINGSQL_E2E_MYSQL_PORT +BLAZINGSQL_E2E_MYSQL_USERNAME +BLAZINGSQL_E2E_MYSQL_PASSWORD +BLAZINGSQL_E2E_MYSQL_DATABASE + +# for PostgreSQL +BLAZINGSQL_E2E_POSTGRESQL_HOSTNAME +BLAZINGSQL_E2E_POSTGRESQL_PORT +BLAZINGSQL_E2E_POSTGRESQL_USERNAME +BLAZINGSQL_E2E_POSTGRESQL_PASSWORD +BLAZINGSQL_E2E_POSTGRESQL_DATABASE + +# for SQLite +BLAZINGSQL_E2E_SQLITE_DATABASE ``` Note BLAZINGSQL_E2E_SQL_PORT is a number and the other vars are strings!
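
The tpchQueries refactor above replaces module-level f-strings with plain `str.format()` templates that `map_tables` resolves at call time, so concrete table names no longer need to exist when the module is imported. A minimal usage sketch, assuming the helpers are importable as `tpchQueries` (the real module path is not shown in this hunk):

```python
# Minimal sketch of the new template helpers; the import path is an assumption.
from tpchQueries import get_tpch_query  # hypothetical module name

# Placeholders such as {lineitem} and {orders} are filled per run, so the same
# query text can target tables registered under different names.
query = get_tpch_query("TEST_22", {"customer": "customer_csv",
                                   "orders": "orders_csv"})

# Tables that are not overridden keep their default names ("lineitem",
# "nation", ...) because map_tables uses tables.get(key, default_name).
print(query)
```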
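
The `skip_test` hunk adds an early-out so that MySQL- and PostgreSQL-backed combinations are skipped when the run settings report `executionMode == 'gpuci'`. A self-contained sketch of just that guard, with the real `DataType` enum and `Settings` lookup replaced by simplified stand-ins:

```python
# Simplified stand-ins: the real DataType enum and Settings object belong to the
# e2e framework; only the new gpuci guard's behavior is reproduced here.
from enum import Enum, auto

class DataType(Enum):
    CSV = auto()
    PARQUET = auto()
    MYSQL = auto()
    POSTGRESQL = auto()

def should_skip_for_gpuci(execution_mode: str, file_schema_type: DataType) -> bool:
    # In "gpuci" execution mode, MySQL and PostgreSQL data sources are skipped.
    return execution_mode == "gpuci" and file_schema_type in (
        DataType.MYSQL,
        DataType.POSTGRESQL,
    )

assert should_skip_for_gpuci("gpuci", DataType.MYSQL)
assert not should_skip_for_gpuci("gpuci", DataType.PARQUET)
assert not should_skip_for_gpuci("local", DataType.POSTGRESQL)  # any non-gpuci mode
```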
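
The README hunk splits the former shared `BLAZINGSQL_E2E_SQL_*` settings into per-backend variables, plus a single database variable for SQLite. A hedged sketch of collecting them from the environment; how the test framework consumes them internally is not shown in this diff, so the dictionary layout below is illustrative only:

```python
# Illustrative only: gather the per-backend e2e settings named in tests/README.md.
import os

mysql_conn = {
    "host": os.environ["BLAZINGSQL_E2E_MYSQL_HOSTNAME"],
    "port": int(os.environ["BLAZINGSQL_E2E_MYSQL_PORT"]),  # the port is numeric
    "user": os.environ["BLAZINGSQL_E2E_MYSQL_USERNAME"],
    "password": os.environ["BLAZINGSQL_E2E_MYSQL_PASSWORD"],
    "database": os.environ["BLAZINGSQL_E2E_MYSQL_DATABASE"],
}

postgresql_conn = {
    "host": os.environ["BLAZINGSQL_E2E_POSTGRESQL_HOSTNAME"],
    "port": int(os.environ["BLAZINGSQL_E2E_POSTGRESQL_PORT"]),
    "user": os.environ["BLAZINGSQL_E2E_POSTGRESQL_USERNAME"],
    "password": os.environ["BLAZINGSQL_E2E_POSTGRESQL_PASSWORD"],
    "database": os.environ["BLAZINGSQL_E2E_POSTGRESQL_DATABASE"],
}

sqlite_conn = {"database": os.environ["BLAZINGSQL_E2E_SQLITE_DATABASE"]}
```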