Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: Change primary key syntax to work for a multi column primary key #11

Merged
merged 5 commits into from
Feb 16, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion includes/sql_generator.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,8 @@ void create_schema(duckdb::Connection &con, const std::string &db_name,
bool table_exists(duckdb::Connection &con, const table_def &table);

void create_table(duckdb::Connection &con, const table_def &table,
const std::vector<column_def> &columns);
const std::vector<const column_def *> &columns_pk,
const std::vector<column_def> &all_columns);

std::vector<column_def> describe_table(duckdb::Connection &con,
const table_def &table);
Expand Down
27 changes: 19 additions & 8 deletions src/motherduck_destination_server.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,19 @@ void process_file(
arrow_array_stream.release(&arrow_array_stream);
}

void find_primary_keys(
const std::vector<column_def> &cols,
std::vector<const column_def *> &columns_pk,
std::vector<const column_def *> *columns_regular = nullptr) {
for (auto &col : cols) {
if (col.primary_key) {
columns_pk.push_back(&col);
} else if (columns_regular != nullptr) {
columns_regular->push_back(&col);
}
}
}

grpc::Status DestinationSdkImpl::ConfigurationForm(
::grpc::ServerContext *context,
const ::fivetran_sdk::ConfigurationFormRequest *request,
Expand Down Expand Up @@ -204,7 +217,10 @@ grpc::Status DestinationSdkImpl::CreateTable(
create_schema(*con, db_name, schema_name);
}

create_table(*con, table, get_duckdb_columns(request->table().columns()));
std::vector<const column_def *> columns_pk;
const auto cols = get_duckdb_columns(request->table().columns());
find_primary_keys(cols, columns_pk);
create_table(*con, table, columns_pk, cols);
response->set_success(true);
} catch (const std::exception &e) {
mdlog::severe("CreateTable endpoint failed for schema <" +
Expand Down Expand Up @@ -297,13 +313,8 @@ DestinationSdkImpl::WriteBatch(::grpc::ServerContext *context,
const auto cols = get_duckdb_columns(request->table().columns());
std::vector<const column_def *> columns_pk;
std::vector<const column_def *> columns_regular;
for (auto &col : cols) {
if (col.primary_key) {
columns_pk.push_back(&col);
} else {
columns_regular.push_back(&col);
}
}
find_primary_keys(cols, columns_pk, &columns_regular);

if (columns_pk.empty()) {
throw std::invalid_argument("No primary keys found");
}
Expand Down
63 changes: 32 additions & 31 deletions src/sql_generator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,11 @@ std::string table_def::to_string() const {
return out.str();
}

const auto print_column = [](const std::string &quoted_col,
std::ostringstream &out) { out << quoted_col; };

void write_joined(
std::ostringstream &sql, const std::vector<const column_def *> columns,
std::ostringstream &sql, const std::vector<const column_def *> &columns,
std::function<void(const std::string &, std::ostringstream &)> print_str) {
bool first = true;
for (const auto &col : columns) {
Expand All @@ -26,7 +29,7 @@ void write_joined(
} else {
sql << ", ";
}
print_str(col->name, sql);
print_str(KeywordHelper::WriteQuoted(col->name, '"'), sql);
}
}

Expand Down Expand Up @@ -78,20 +81,24 @@ void create_schema(duckdb::Connection &con, const std::string &db_name,
}

void create_table(duckdb::Connection &con, const table_def &table,
const std::vector<column_def> &columns) {
const std::vector<const column_def *> &columns_pk,
const std::vector<column_def> &all_columns) {
const std::string absolute_table_name = table.to_string();
std::ostringstream ddl;
ddl << "CREATE OR REPLACE TABLE " << absolute_table_name << " (";

for (const auto &col : columns) {
for (const auto &col : all_columns) {
ddl << KeywordHelper::WriteQuoted(col.name, '"') << " "
<< duckdb::EnumUtil::ToChars(col.type);
if (col.primary_key) {
ddl << " PRIMARY KEY";
}
ddl << ", "; // DuckDB allows trailing commas
}

if (!columns_pk.empty()) {
ddl << "PRIMARY KEY (";
write_joined(ddl, columns_pk, print_column);
ddl << ")";
}

ddl << ")";

auto query = ddl.str();
Expand Down Expand Up @@ -232,15 +239,12 @@ void upsert(duckdb::Connection &con, const table_def &table,
<< staging_table_name;
if (!columns_pk.empty()) {
sql << " ON CONFLICT (";
write_joined(
sql, columns_pk,
[](const std::string &name, std::ostringstream &out) { out << name; });
write_joined(sql, columns_pk, print_column);
sql << " ) DO UPDATE SET ";

write_joined(sql, columns_regular,
[](const std::string &name, std::ostringstream &out) {
out << KeywordHelper::WriteQuoted(name, '"') << " = "
<< "excluded." << KeywordHelper::WriteQuoted(name, '"');
[](const std::string &quoted_col, std::ostringstream &out) {
out << quoted_col << " = excluded." << quoted_col;
});
}

Expand All @@ -266,23 +270,21 @@ void update_values(duckdb::Connection &con, const table_def &table,

write_joined(sql, columns_regular,
[staging_table_name, absolute_table_name, unmodified_string](
const std::string name, std::ostringstream &out) {
auto colname = KeywordHelper::WriteQuoted(name, '"');
out << colname << " = CASE WHEN " << staging_table_name << "."
<< colname << " = "
const std::string quoted_col, std::ostringstream &out) {
out << quoted_col << " = CASE WHEN " << staging_table_name
<< "." << quoted_col << " = "
<< KeywordHelper::WriteQuoted(unmodified_string, '\'')
<< " THEN " << absolute_table_name << "." << colname
<< " ELSE " << staging_table_name << "." << colname
<< " THEN " << absolute_table_name << "." << quoted_col
<< " ELSE " << staging_table_name << "." << quoted_col
<< " END";
});

sql << " FROM " << staging_table_name << " WHERE ";
write_joined(
sql, columns_pk, [&](const std::string &pk, std::ostringstream &out) {
out << table.table_name << "." << KeywordHelper::WriteQuoted(pk, '"')
<< " = " << staging_table_name << "."
<< KeywordHelper::WriteQuoted(pk, '"');
});
write_joined(sql, columns_pk,
[&](const std::string &quoted_col, std::ostringstream &out) {
out << table.table_name << "." << quoted_col << " = "
<< staging_table_name << "." << quoted_col;
});

auto query = sql.str();
mdlog::info("update: " + query);
Expand All @@ -302,12 +304,11 @@ void delete_rows(duckdb::Connection &con, const table_def &table,
sql << "DELETE FROM " + absolute_table_name << " USING " << staging_table_name
<< " WHERE ";

write_joined(
sql, columns_pk, [&](const std::string &pk, std::ostringstream &out) {
out << table.table_name << "." << KeywordHelper::WriteQuoted(pk, '"')
<< " = " << staging_table_name << "."
<< KeywordHelper::WriteQuoted(pk, '"');
});
write_joined(sql, columns_pk,
[&](const std::string &quoted_col, std::ostringstream &out) {
out << table.table_name << "." << quoted_col << " = "
<< staging_table_name << "." << quoted_col;
});

auto query = sql.str();
mdlog::info("delete_rows: " + query);
Expand Down
48 changes: 48 additions & 0 deletions test/integration/test_server.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -501,4 +501,52 @@ TEST_CASE("Truncate nonexistent table should succeed", "[integration]") {
REQUIRE_THAT(buffer.str(), Catch::Matchers::ContainsSubstring(
"Table <nonexistent> not found in schema "
"<some_schema>; not truncated"));
}

TEST_CASE("CreateTable with multiple primary keys", "[integration]") {
DestinationSdkImpl service;

const std::string table_name =
"multikey_table" + std::to_string(Catch::rngSeed());
auto token = std::getenv("motherduck_token");
REQUIRE(token);

{
// Create Table
::fivetran_sdk::CreateTableRequest request;
(*request.mutable_configuration())["motherduck_token"] = token;
(*request.mutable_configuration())["motherduck_database"] = "fivetran_test";
request.mutable_table()->set_name(table_name);
auto col1 = request.mutable_table()->add_columns();
col1->set_name("id1");
col1->set_type(::fivetran_sdk::DataType::INT);
col1->set_primary_key(true);
auto col2 = request.mutable_table()->add_columns();
col2->set_name("id2");
col2->set_type(::fivetran_sdk::DataType::INT);
col2->set_primary_key(true);

::fivetran_sdk::CreateTableResponse response;
auto status = service.CreateTable(nullptr, &request, &response);

INFO(status.error_message());
REQUIRE(status.ok());
}

{
// Describe the created table
::fivetran_sdk::DescribeTableRequest request;
(*request.mutable_configuration())["motherduck_token"] = token;
(*request.mutable_configuration())["motherduck_database"] = "fivetran_test";
request.set_table_name(table_name);

{
::fivetran_sdk::DescribeTableResponse response;
auto status = service.DescribeTable(nullptr, &request, &response);

INFO(status.error_message());
REQUIRE(status.ok());
REQUIRE(response.table().columns().size() == 2);
}
}
}
Loading