Skip to content

Commit

Permalink
Merge pull request #147 from Mytherin/issue146
Browse files Browse the repository at this point in the history
Fix #146 by materializing child postgres scans for CREATE TABLE AS as well and fix formatting
  • Loading branch information
Mytherin authored Dec 11, 2023
2 parents ccba10d + 5be679d commit bbb93b5
Show file tree
Hide file tree
Showing 29 changed files with 295 additions and 227 deletions.
2 changes: 1 addition & 1 deletion Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ test_debug: debug

format:
cp duckdb/.clang-format .
find src/ -iname *.hpp -o -iname *.cpp | xargs clang-format --sort-includes=0 -style=file -i
find src/ -iname "*.hpp" -o -iname "*.cpp" | xargs clang-format --sort-includes=0 -style=file -i
cmake-format -i CMakeLists.txt
rm .clang-format

Expand Down
8 changes: 4 additions & 4 deletions src/postgres_attach.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -50,15 +50,15 @@ static void AttachFunction(ClientContext &context, TableFunctionInput &data_p, D
auto conn = PostgresConnection::Open(data.dsn);
auto dconn = Connection(context.db->GetDatabase(context));
auto fetch_table_query = StringUtil::Format(
R"(
R"(
SELECT relname
FROM pg_class JOIN pg_namespace ON pg_class.relnamespace = pg_namespace.oid
JOIN pg_attribute ON pg_class.oid = pg_attribute.attrelid
WHERE relkind = 'r' AND attnum > 0 AND nspname = %s
GROUP BY relname
ORDER BY relname;
)",
KeywordHelper::WriteQuoted(data.source_schema));
KeywordHelper::WriteQuoted(data.source_schema));
auto res = conn.Query(fetch_table_query);
for (idx_t row = 0; row < PQntuples(res->res); row++) {
auto table_name = res->GetString(row, 0);
Expand Down Expand Up @@ -93,7 +93,7 @@ ORDER BY relname;
}

PostgresAttachFunction::PostgresAttachFunction()
: TableFunction("postgres_attach", {LogicalType::VARCHAR}, AttachFunction, AttachBind) {
: TableFunction("postgres_attach", {LogicalType::VARCHAR}, AttachFunction, AttachBind) {
named_parameters["overwrite"] = LogicalType::BOOLEAN;
named_parameters["filter_pushdown"] = LogicalType::BOOLEAN;

Expand All @@ -102,4 +102,4 @@ PostgresAttachFunction::PostgresAttachFunction()
named_parameters["suffix"] = LogicalType::VARCHAR;
}

}
} // namespace duckdb
14 changes: 8 additions & 6 deletions src/postgres_connection.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ static bool ResultHasError(PGresult *result) {
if (!result) {
return true;
}
switch(PQresultStatus(result)) {
switch (PQresultStatus(result)) {
case PGRES_COMMAND_OK:
case PGRES_TUPLES_OK:
return false;
Expand All @@ -54,12 +54,12 @@ PGresult *PostgresConnection::PQExecute(const string &query) {
return PQexec(GetConn(), query.c_str());
}


unique_ptr<PostgresResult> PostgresConnection::TryQuery(const string &query, optional_ptr<string> error_message) {
auto result = PQExecute(query.c_str());
if (ResultHasError(result)) {
if (error_message) {
*error_message = StringUtil::Format("Failed to execute query \"" + query + "\": " + string(PQresultErrorMessage(result)));
*error_message = StringUtil::Format("Failed to execute query \"" + query +
"\": " + string(PQresultErrorMessage(result)));
}
return nullptr;
}
Expand Down Expand Up @@ -88,13 +88,14 @@ vector<unique_ptr<PostgresResult>> PostgresConnection::ExecuteQueries(const stri
throw std::runtime_error("Failed to execute query \"" + queries + "\": " + string(PQerrorMessage(GetConn())));
}
vector<unique_ptr<PostgresResult>> results;
while(true) {
while (true) {
auto res = PQgetResult(GetConn());
if (!res) {
break;
}
if (ResultHasError(res)) {
throw std::runtime_error("Failed to execute query \"" + queries + "\": " + string(PQresultErrorMessage(res)));
throw std::runtime_error("Failed to execute query \"" + queries +
"\": " + string(PQresultErrorMessage(res)));
}
if (PQresultStatus(res) != PGRES_TUPLES_OK) {
continue;
Expand All @@ -106,7 +107,8 @@ vector<unique_ptr<PostgresResult>> PostgresConnection::ExecuteQueries(const stri
}

PostgresVersion PostgresConnection::GetPostgresVersion() {
auto result = Query("SELECT CURRENT_SETTING('server_version'), (SELECT COUNT(*) FROM pg_settings WHERE name LIKE 'rds%')");
auto result =
Query("SELECT CURRENT_SETTING('server_version'), (SELECT COUNT(*) FROM pg_settings WHERE name LIKE 'rds%')");
auto version = PostgresUtils::ExtractPostgresVersion(result->GetString(0, 0));
if (result->GetInt64(0, 1) > 0) {
version.type_v = PostgresInstanceType::AURORA;
Expand Down
2 changes: 1 addition & 1 deletion src/postgres_copy_from.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,4 +12,4 @@ void PostgresConnection::BeginCopyFrom(PostgresBinaryReader &reader, const strin
reader.CheckHeader();
}

}
} // namespace duckdb
38 changes: 20 additions & 18 deletions src/postgres_copy_to.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,15 +5,17 @@

namespace duckdb {

void PostgresConnection::BeginCopyTo(ClientContext &context, PostgresCopyState &state, PostgresCopyFormat format, const string &schema_name, const string &table_name, const vector<string> &column_names) {
void PostgresConnection::BeginCopyTo(ClientContext &context, PostgresCopyState &state, PostgresCopyFormat format,
const string &schema_name, const string &table_name,
const vector<string> &column_names) {
string query = "COPY ";
if (!schema_name.empty()) {
query += KeywordHelper::WriteQuoted(schema_name, '"') + ".";
}
query += KeywordHelper::WriteQuoted(table_name, '"') + " ";
if (!column_names.empty()) {
query += "(";
for(idx_t c = 0; c < column_names.size(); c++) {
for (idx_t c = 0; c < column_names.size(); c++) {
if (c > 0) {
query += ", ";
}
Expand All @@ -23,7 +25,7 @@ void PostgresConnection::BeginCopyTo(ClientContext &context, PostgresCopyState &
}
query += "FROM STDIN (FORMAT ";
state.format = format;
switch(state.format) {
switch (state.format) {
case PostgresCopyFormat::BINARY:
query += "BINARY";
break;
Expand All @@ -50,8 +52,8 @@ void PostgresConnection::BeginCopyTo(ClientContext &context, PostgresCopyState &
void PostgresConnection::CopyData(data_ptr_t buffer, idx_t size) {
int result;
do {
result = PQputCopyData(GetConn(), (const char *) buffer, int(size));
} while(result == 0);
result = PQputCopyData(GetConn(), (const char *)buffer, int(size));
} while (result == 0);
if (result == -1) {
throw InternalException("Error during PQputCopyData: %s", PQerrorMessage(GetConn()));
}
Expand Down Expand Up @@ -103,15 +105,15 @@ void CastListToPostgresArray(ClientContext &context, Vector &input, Vector &varc
auto child_entries = FlatVector::GetData<string_t>(child_varchar);
auto list_entries = FlatVector::GetData<list_entry_t>(input);
auto result_entries = FlatVector::GetData<string_t>(varchar_vector);
for(idx_t r = 0; r < size; r++) {
for (idx_t r = 0; r < size; r++) {
if (FlatVector::IsNull(input, r)) {
FlatVector::SetNull(varchar_vector, r, true);
continue;
}
auto list_entry = list_entries[r];
string result;
result = "{";
for(idx_t list_idx = 0; list_idx < list_entry.length; list_idx++) {
for (idx_t list_idx = 0; list_idx < list_entry.length; list_idx++) {
if (list_idx > 0) {
result += ",";
}
Expand All @@ -134,7 +136,7 @@ void CastListToPostgresArray(ClientContext &context, Vector &input, Vector &varc
}

bool TypeRequiresQuotes(const LogicalType &input) {
switch(input.id()) {
switch (input.id()) {
case LogicalTypeId::STRUCT:
case LogicalTypeId::LIST:
return true;
Expand All @@ -148,7 +150,7 @@ void CastStructToPostgres(ClientContext &context, Vector &input, Vector &varchar
// cast child data of structs
vector<Vector> child_varchar_vectors;
vector<bool> child_requires_quotes;
for(idx_t c = 0; c < child_vectors.size(); c++) {
for (idx_t c = 0; c < child_vectors.size(); c++) {
Vector child_varchar(LogicalType::VARCHAR, size);
CastToPostgresVarchar(context, *child_vectors[c], child_varchar, size, depth + 1);
child_varchar_vectors.push_back(std::move(child_varchar));
Expand All @@ -157,14 +159,14 @@ void CastStructToPostgres(ClientContext &context, Vector &input, Vector &varchar

// construct the struct entries
auto result_entries = FlatVector::GetData<string_t>(varchar_vector);
for(idx_t r = 0; r < size; r++) {
for (idx_t r = 0; r < size; r++) {
if (FlatVector::IsNull(input, r)) {
FlatVector::SetNull(varchar_vector, r, true);
continue;
}
string result;
result = "(";
for(idx_t c = 0; c < child_varchar_vectors.size(); c++) {
for (idx_t c = 0; c < child_varchar_vectors.size(); c++) {
if (c > 0) {
result += ",";
}
Expand All @@ -189,7 +191,7 @@ void CastStructToPostgres(ClientContext &context, Vector &input, Vector &varchar
void CastBlobToPostgres(ClientContext &context, Vector &input, Vector &result, idx_t size) {
auto input_data = FlatVector::GetData<string_t>(input);
auto result_data = FlatVector::GetData<string_t>(result);
for(idx_t r = 0; r < size; r++) {
for (idx_t r = 0; r < size; r++) {
if (FlatVector::IsNull(input, r)) {
FlatVector::SetNull(result, r, true);
continue;
Expand All @@ -198,7 +200,7 @@ void CastBlobToPostgres(ClientContext &context, Vector &input, Vector &result, i
string blob_str = "\\\\x";
auto blob_data = const_data_ptr_cast(input_data[r].GetData());
auto blob_size = input_data[r].GetSize();
for(idx_t c = 0; c < blob_size; c++) {
for (idx_t c = 0; c < blob_size; c++) {
blob_str += HEX_STRING[blob_data[c] / 16];
blob_str += HEX_STRING[blob_data[c] % 16];
}
Expand All @@ -223,7 +225,8 @@ void CastToPostgresVarchar(ClientContext &context, Vector &input, Vector &result
}
}

void PostgresConnection::CopyChunk(ClientContext &context, PostgresCopyState &state, DataChunk &chunk, DataChunk &varchar_chunk) {
void PostgresConnection::CopyChunk(ClientContext &context, PostgresCopyState &state, DataChunk &chunk,
DataChunk &varchar_chunk) {
chunk.Flatten();

if (state.format == PostgresCopyFormat::BINARY) {
Expand All @@ -242,14 +245,14 @@ void PostgresConnection::CopyChunk(ClientContext &context, PostgresCopyState &st
if (varchar_chunk.ColumnCount() == 0) {
// not initialized yet
vector<LogicalType> varchar_types;
for(idx_t c = 0; c < chunk.ColumnCount(); c++) {
for (idx_t c = 0; c < chunk.ColumnCount(); c++) {
varchar_types.push_back(LogicalType::VARCHAR);
}
varchar_chunk.Initialize(Allocator::DefaultAllocator(), varchar_types);
}
D_ASSERT(chunk.ColumnCount() == varchar_chunk.ColumnCount());
// for text format cast to varchar first
for(idx_t c = 0; c < chunk.ColumnCount(); c++) {
for (idx_t c = 0; c < chunk.ColumnCount(); c++) {
CastToPostgresVarchar(context, chunk.data[c], varchar_chunk.data[c], chunk.size());
}
varchar_chunk.SetCardinality(chunk.size());
Expand All @@ -270,5 +273,4 @@ void PostgresConnection::CopyChunk(ClientContext &context, PostgresCopyState &st
}
}


}
} // namespace duckdb
28 changes: 17 additions & 11 deletions src/postgres_extension.cpp
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
#define DUCKDB_BUILD_LOADABLE_EXTENSION
#include "duckdb.hpp"


#include "postgres_scanner.hpp"
#include "postgres_storage.hpp"
#include "postgres_scanner_extension.hpp"
Expand All @@ -20,7 +19,7 @@ static void SetPostgresConnectionLimit(ClientContext &context, SetScope scope, V
throw InvalidInputException("pg_connection_limit can only be set globally");
}
auto databases = DatabaseManager::Get(context).GetDatabases(context);
for(auto &db_ref : databases) {
for (auto &db_ref : databases) {
auto &db = db_ref.get();
auto &catalog = db.GetCatalog();
if (catalog.GetCatalogType() != "postgres") {
Expand Down Expand Up @@ -55,17 +54,25 @@ static void LoadInternal(DatabaseInstance &db) {
auto &config = DBConfig::GetConfig(db);
config.storage_extensions["postgres_scanner"] = make_uniq<PostgresStorageExtension>();

config.AddExtensionOption("pg_use_binary_copy", "Whether or not to use BINARY copy to read data", LogicalType::BOOLEAN, Value::BOOLEAN(true));
config.AddExtensionOption("pg_pages_per_task", "The amount of pages per task", LogicalType::UBIGINT, Value::UBIGINT(PostgresBindData::DEFAULT_PAGES_PER_TASK));
config.AddExtensionOption("pg_connection_limit", "The maximum amount of concurrent Postgres connections", LogicalType::UBIGINT, Value::UBIGINT(PostgresConnectionPool::DEFAULT_MAX_CONNECTIONS), SetPostgresConnectionLimit);
config.AddExtensionOption("pg_array_as_varchar", "Read Postgres arrays as varchar - enables reading mixed dimensional arrays", LogicalType::BOOLEAN, Value::BOOLEAN(false));
config.AddExtensionOption("pg_experimental_filter_pushdown", "Whether or not to use filter pushdown (currently experimental)", LogicalType::BOOLEAN, Value::BOOLEAN(false));
config.AddExtensionOption("pg_debug_show_queries", "DEBUG SETTING: print all queries sent to Postgres to stdout", LogicalType::BOOLEAN, Value::BOOLEAN(false), SetPostgresDebugQueryPrint);

config.AddExtensionOption("pg_use_binary_copy", "Whether or not to use BINARY copy to read data",
LogicalType::BOOLEAN, Value::BOOLEAN(true));
config.AddExtensionOption("pg_pages_per_task", "The amount of pages per task", LogicalType::UBIGINT,
Value::UBIGINT(PostgresBindData::DEFAULT_PAGES_PER_TASK));
config.AddExtensionOption("pg_connection_limit", "The maximum amount of concurrent Postgres connections",
LogicalType::UBIGINT, Value::UBIGINT(PostgresConnectionPool::DEFAULT_MAX_CONNECTIONS),
SetPostgresConnectionLimit);
config.AddExtensionOption("pg_array_as_varchar",
"Read Postgres arrays as varchar - enables reading mixed dimensional arrays",
LogicalType::BOOLEAN, Value::BOOLEAN(false));
config.AddExtensionOption("pg_experimental_filter_pushdown",
"Whether or not to use filter pushdown (currently experimental)", LogicalType::BOOLEAN,
Value::BOOLEAN(false));
config.AddExtensionOption("pg_debug_show_queries", "DEBUG SETTING: print all queries sent to Postgres to stdout",
LogicalType::BOOLEAN, Value::BOOLEAN(false), SetPostgresDebugQueryPrint);
}

void PostgresScannerExtension::Load(DuckDB &db) {
LoadInternal(*db.instance);
LoadInternal(*db.instance);
}

extern "C" {
Expand All @@ -81,5 +88,4 @@ DUCKDB_EXTENSION_API const char *postgres_scanner_version() {
// Registers the Postgres storage extension on the supplied configuration.
// Creates a new PostgresStorageExtension and stores it in
// config.storage_extensions under the key "postgres_scanner" (the same key
// LoadInternal uses when it registers the extension on a database's config).
// Declared inside the extern "C" block, so the symbol has C linkage.
// NOTE(review): presumably invoked by DuckDB when only the storage extension
// needs to be initialized - confirm against the extension loader.
DUCKDB_EXTENSION_API void postgres_scanner_storage_init(DBConfig &config) {
config.storage_extensions["postgres_scanner"] = make_uniq<PostgresStorageExtension>();
}

}
8 changes: 5 additions & 3 deletions src/postgres_filter_pushdown.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,8 @@

namespace duckdb {

string PostgresFilterPushdown::CreateExpression(string &column_name, vector<unique_ptr<TableFilter>> &filters, string op) {
string PostgresFilterPushdown::CreateExpression(string &column_name, vector<unique_ptr<TableFilter>> &filters,
string op) {
vector<string> filter_entries;
for (auto &filter : filters) {
filter_entries.push_back(TransformFilter(column_name, *filter));
Expand Down Expand Up @@ -55,7 +56,8 @@ string PostgresFilterPushdown::TransformFilter(string &column_name, TableFilter
}
}

string PostgresFilterPushdown::TransformFilters(const vector<column_t> &column_ids, optional_ptr<TableFilterSet> filters, const vector<string> &names) {
string PostgresFilterPushdown::TransformFilters(const vector<column_t> &column_ids,
optional_ptr<TableFilterSet> filters, const vector<string> &names) {
if (!filters || filters->filters.empty()) {
// no filters
return string();
Expand All @@ -72,4 +74,4 @@ string PostgresFilterPushdown::TransformFilters(const vector<column_t> &column_i
return result;
}

}
} // namespace duckdb
13 changes: 7 additions & 6 deletions src/postgres_query.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
namespace duckdb {

static unique_ptr<FunctionData> PGQueryBind(ClientContext &context, TableFunctionBindInput &input,
vector<LogicalType> &return_types, vector<string> &names) {
vector<LogicalType> &return_types, vector<string> &names) {
auto result = make_uniq<PostgresBindData>();

// look up the database to query
Expand Down Expand Up @@ -48,9 +48,11 @@ static unique_ptr<FunctionData> PGQueryBind(ClientContext &context, TableFunctio
}
auto nfields = PQnfields(describe_prepared);
if (nfields <= 0) {
throw BinderException("No fields returned by query \"%s\" - the query must be a SELECT statement that returns at least one column", sql);
throw BinderException("No fields returned by query \"%s\" - the query must be a SELECT statement that returns "
"at least one column",
sql);
}
for(idx_t c = 0; c < nfields; c++) {
for (idx_t c = 0; c < nfields; c++) {
PostgresType postgres_type;
postgres_type.oid = PQftype(describe_prepared, c);
PostgresTypeData type_data;
Expand All @@ -75,12 +77,11 @@ static unique_ptr<FunctionData> PGQueryBind(ClientContext &context, TableFunctio
}

PostgresQueryFunction::PostgresQueryFunction()
: TableFunction("postgres_query", {LogicalType::VARCHAR, LogicalType::VARCHAR},
nullptr, PGQueryBind) {
: TableFunction("postgres_query", {LogicalType::VARCHAR, LogicalType::VARCHAR}, nullptr, PGQueryBind) {
PostgresScanFunction scan_function;
init_global = scan_function.init_global;
init_local = scan_function.init_local;
function = scan_function.function;
projection_pushdown = true;
}
}
} // namespace duckdb
Loading

0 comments on commit bbb93b5

Please sign in to comment.