diff --git a/includes/sql_generator.hpp b/includes/sql_generator.hpp index b66e2a3..ab9414b 100644 --- a/includes/sql_generator.hpp +++ b/includes/sql_generator.hpp @@ -27,7 +27,8 @@ void create_schema(duckdb::Connection &con, const std::string &db_name, bool table_exists(duckdb::Connection &con, const table_def &table); void create_table(duckdb::Connection &con, const table_def &table, - const std::vector &columns); + const std::vector &columns_pk, + const std::vector &all_columns); std::vector describe_table(duckdb::Connection &con, const table_def &table); diff --git a/src/motherduck_destination_server.cpp b/src/motherduck_destination_server.cpp index f06560e..4702616 100644 --- a/src/motherduck_destination_server.cpp +++ b/src/motherduck_destination_server.cpp @@ -113,6 +113,19 @@ void process_file( arrow_array_stream.release(&arrow_array_stream); } +void find_primary_keys( + const std::vector &cols, + std::vector &columns_pk, + std::vector *columns_regular = nullptr) { + for (auto &col : cols) { + if (col.primary_key) { + columns_pk.push_back(&col); + } else if (columns_regular != nullptr) { + columns_regular->push_back(&col); + } + } +} + grpc::Status DestinationSdkImpl::ConfigurationForm( ::grpc::ServerContext *context, const ::fivetran_sdk::ConfigurationFormRequest *request, @@ -204,7 +217,10 @@ grpc::Status DestinationSdkImpl::CreateTable( create_schema(*con, db_name, schema_name); } - create_table(*con, table, get_duckdb_columns(request->table().columns())); + std::vector columns_pk; + const auto cols = get_duckdb_columns(request->table().columns()); + find_primary_keys(cols, columns_pk); + create_table(*con, table, columns_pk, cols); response->set_success(true); } catch (const std::exception &e) { mdlog::severe("CreateTable endpoint failed for schema <" + @@ -297,13 +313,8 @@ DestinationSdkImpl::WriteBatch(::grpc::ServerContext *context, const auto cols = get_duckdb_columns(request->table().columns()); std::vector columns_pk; std::vector columns_regular; - for (auto &col : cols) { - if (col.primary_key) { - columns_pk.push_back(&col); - } else { - columns_regular.push_back(&col); - } - } + find_primary_keys(cols, columns_pk, &columns_regular); + if (columns_pk.empty()) { throw std::invalid_argument("No primary keys found"); } diff --git a/src/sql_generator.cpp b/src/sql_generator.cpp index 4bfe243..08d697a 100644 --- a/src/sql_generator.cpp +++ b/src/sql_generator.cpp @@ -16,8 +16,11 @@ std::string table_def::to_string() const { return out.str(); } +const auto print_column = [](const std::string "ed_col, + std::ostringstream &out) { out << quoted_col; }; + void write_joined( - std::ostringstream &sql, const std::vector columns, + std::ostringstream &sql, const std::vector &columns, std::function print_str) { bool first = true; for (const auto &col : columns) { @@ -26,7 +29,7 @@ void write_joined( } else { sql << ", "; } - print_str(col->name, sql); + print_str(KeywordHelper::WriteQuoted(col->name, '"'), sql); } } @@ -78,20 +81,24 @@ void create_schema(duckdb::Connection &con, const std::string &db_name, } void create_table(duckdb::Connection &con, const table_def &table, - const std::vector &columns) { + const std::vector &columns_pk, + const std::vector &all_columns) { const std::string absolute_table_name = table.to_string(); std::ostringstream ddl; ddl << "CREATE OR REPLACE TABLE " << absolute_table_name << " ("; - for (const auto &col : columns) { + for (const auto &col : all_columns) { ddl << KeywordHelper::WriteQuoted(col.name, '"') << " " << duckdb::EnumUtil::ToChars(col.type); - if (col.primary_key) { - ddl << " PRIMARY KEY"; - } ddl << ", "; // DuckDB allows trailing commas } + if (!columns_pk.empty()) { + ddl << "PRIMARY KEY ("; + write_joined(ddl, columns_pk, print_column); + ddl << ")"; + } + ddl << ")"; auto query = ddl.str(); @@ -232,15 +239,12 @@ void upsert(duckdb::Connection &con, const table_def &table, << staging_table_name; if (!columns_pk.empty()) { sql << " ON CONFLICT ("; - write_joined( - sql, columns_pk, - [](const std::string &name, std::ostringstream &out) { out << name; }); + write_joined(sql, columns_pk, print_column); sql << " ) DO UPDATE SET "; write_joined(sql, columns_regular, - [](const std::string &name, std::ostringstream &out) { - out << KeywordHelper::WriteQuoted(name, '"') << " = " - << "excluded." << KeywordHelper::WriteQuoted(name, '"'); + [](const std::string "ed_col, std::ostringstream &out) { + out << quoted_col << " = excluded." << quoted_col; }); } @@ -266,23 +270,21 @@ void update_values(duckdb::Connection &con, const table_def &table, write_joined(sql, columns_regular, [staging_table_name, absolute_table_name, unmodified_string]( - const std::string name, std::ostringstream &out) { - auto colname = KeywordHelper::WriteQuoted(name, '"'); - out << colname << " = CASE WHEN " << staging_table_name << "." - << colname << " = " + const std::string quoted_col, std::ostringstream &out) { + out << quoted_col << " = CASE WHEN " << staging_table_name + << "." << quoted_col << " = " << KeywordHelper::WriteQuoted(unmodified_string, '\'') - << " THEN " << absolute_table_name << "." << colname - << " ELSE " << staging_table_name << "." << colname + << " THEN " << absolute_table_name << "." << quoted_col + << " ELSE " << staging_table_name << "." << quoted_col << " END"; }); sql << " FROM " << staging_table_name << " WHERE "; - write_joined( - sql, columns_pk, [&](const std::string &pk, std::ostringstream &out) { - out << table.table_name << "." << KeywordHelper::WriteQuoted(pk, '"') - << " = " << staging_table_name << "." - << KeywordHelper::WriteQuoted(pk, '"'); - }); + write_joined(sql, columns_pk, + [&](const std::string "ed_col, std::ostringstream &out) { + out << table.table_name << "." << quoted_col << " = " + << staging_table_name << "." << quoted_col; + }); auto query = sql.str(); mdlog::info("update: " + query); @@ -302,12 +304,11 @@ void delete_rows(duckdb::Connection &con, const table_def &table, sql << "DELETE FROM " + absolute_table_name << " USING " << staging_table_name << " WHERE "; - write_joined( - sql, columns_pk, [&](const std::string &pk, std::ostringstream &out) { - out << table.table_name << "." << KeywordHelper::WriteQuoted(pk, '"') - << " = " << staging_table_name << "." - << KeywordHelper::WriteQuoted(pk, '"'); - }); + write_joined(sql, columns_pk, + [&](const std::string "ed_col, std::ostringstream &out) { + out << table.table_name << "." << quoted_col << " = " + << staging_table_name << "." << quoted_col; + }); auto query = sql.str(); mdlog::info("delete_rows: " + query); diff --git a/test/integration/test_server.cpp b/test/integration/test_server.cpp index 5cbbee0..cb24f64 100644 --- a/test/integration/test_server.cpp +++ b/test/integration/test_server.cpp @@ -501,4 +501,52 @@ TEST_CASE("Truncate nonexistent table should succeed", "[integration]") { REQUIRE_THAT(buffer.str(), Catch::Matchers::ContainsSubstring( "Table not found in schema " "; not truncated")); +} + +TEST_CASE("CreateTable with multiple primary keys", "[integration]") { + DestinationSdkImpl service; + + const std::string table_name = + "multikey_table" + std::to_string(Catch::rngSeed()); + auto token = std::getenv("motherduck_token"); + REQUIRE(token); + + { + // Create Table + ::fivetran_sdk::CreateTableRequest request; + (*request.mutable_configuration())["motherduck_token"] = token; + (*request.mutable_configuration())["motherduck_database"] = "fivetran_test"; + request.mutable_table()->set_name(table_name); + auto col1 = request.mutable_table()->add_columns(); + col1->set_name("id1"); + col1->set_type(::fivetran_sdk::DataType::INT); + col1->set_primary_key(true); + auto col2 = request.mutable_table()->add_columns(); + col2->set_name("id2"); + col2->set_type(::fivetran_sdk::DataType::INT); + col2->set_primary_key(true); + + ::fivetran_sdk::CreateTableResponse response; + auto status = service.CreateTable(nullptr, &request, &response); + + INFO(status.error_message()); + REQUIRE(status.ok()); + } + + { + // Describe the created table + ::fivetran_sdk::DescribeTableRequest request; + (*request.mutable_configuration())["motherduck_token"] = token; + (*request.mutable_configuration())["motherduck_database"] = "fivetran_test"; + request.set_table_name(table_name); + + { + ::fivetran_sdk::DescribeTableResponse response; + auto status = service.DescribeTable(nullptr, &request, &response); + + INFO(status.error_message()); + REQUIRE(status.ok()); + REQUIRE(response.table().columns().size() == 2); + } + } } \ No newline at end of file