diff --git a/cpp/src/arrow/flight/sql/odbc/odbc_api.cc b/cpp/src/arrow/flight/sql/odbc/odbc_api.cc index 01780f0efe2..ac75d747330 100644 --- a/cpp/src/arrow/flight/sql/odbc/odbc_api.cc +++ b/cpp/src/arrow/flight/sql/odbc/odbc_api.cc @@ -976,8 +976,23 @@ SQLRETURN SQLNativeSql(SQLHDBC conn, SQLWCHAR* in_statement_text, << ", buffer_length: " << buffer_length << ", out_statement_text_length: " << static_cast(out_statement_text_length); - // GH-47723 TODO: Implement SQLNativeSql - return SQL_INVALID_HANDLE; + + using ODBC::GetAttributeSQLWCHAR; + using ODBC::ODBCConnection; + using ODBC::SqlWcharToString; + + return ODBCConnection::ExecuteWithDiagnostics(conn, SQL_ERROR, [=]() { + const bool is_length_in_bytes = false; + + ODBCConnection* connection = reinterpret_cast(conn); + Diagnostics& diagnostics = connection->GetDiagnostics(); + + std::string in_statement_str = + SqlWcharToString(in_statement_text, in_statement_text_length); + + return GetAttributeSQLWCHAR(in_statement_str, is_length_in_bytes, out_statement_text, + buffer_length, out_statement_text_length, diagnostics); + }); } SQLRETURN SQLDescribeCol(SQLHSTMT stmt, SQLUSMALLINT column_number, SQLWCHAR* column_name, diff --git a/cpp/src/arrow/flight/sql/odbc/tests/CMakeLists.txt b/cpp/src/arrow/flight/sql/odbc/tests/CMakeLists.txt index 4bc240637e7..cf3e15451d9 100644 --- a/cpp/src/arrow/flight/sql/odbc/tests/CMakeLists.txt +++ b/cpp/src/arrow/flight/sql/odbc/tests/CMakeLists.txt @@ -35,6 +35,7 @@ add_arrow_test(flight_sql_odbc_test odbc_test_suite.cc odbc_test_suite.h connection_test.cc + statement_test.cc # Enable Protobuf cleanup after test execution # GH-46889: move protobuf_test_util to a more common location ../../../../engine/substrait/protobuf_test_util.cc diff --git a/cpp/src/arrow/flight/sql/odbc/tests/statement_test.cc b/cpp/src/arrow/flight/sql/odbc/tests/statement_test.cc new file mode 100644 index 00000000000..9d6d42c4a11 --- /dev/null +++ b/cpp/src/arrow/flight/sql/odbc/tests/statement_test.cc @@ -0,0 +1,143 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. +#include "arrow/flight/sql/odbc/tests/odbc_test_suite.h" + +#include "arrow/flight/sql/odbc/odbc_impl/platform.h" + +#include +#include +#include + +#include + +#include +#include + +namespace arrow::flight::sql::odbc { + +template +class StatementTest : public T {}; + +class StatementMockTest : public FlightSQLODBCMockTestBase {}; +class StatementRemoteTest : public FlightSQLODBCRemoteTestBase {}; +using TestTypes = ::testing::Types; +TYPED_TEST_SUITE(StatementTest, TestTypes); + +TYPED_TEST(StatementTest, TestSQLNativeSqlReturnsInputString) { + SQLWCHAR buf[1024]; + SQLINTEGER buf_char_len = sizeof(buf) / ODBC::GetSqlWCharSize(); + SQLWCHAR input_str[] = L"SELECT * FROM mytable WHERE id == 1"; + SQLINTEGER input_char_len = static_cast(wcslen(input_str)); + SQLINTEGER output_char_len = 0; + std::wstring expected_string = std::wstring(input_str); + + ASSERT_EQ(SQL_SUCCESS, SQLNativeSql(this->conn, input_str, input_char_len, buf, + buf_char_len, &output_char_len)); + + EXPECT_EQ(input_char_len, output_char_len); + + // returned length is in characters + std::wstring returned_string(buf, buf + output_char_len); + + EXPECT_EQ(expected_string, returned_string); +} + +TYPED_TEST(StatementTest, TestSQLNativeSqlReturnsNTSInputString) { + SQLWCHAR buf[1024]; + SQLINTEGER buf_char_len = sizeof(buf) / ODBC::GetSqlWCharSize(); + SQLWCHAR input_str[] = L"SELECT * FROM mytable WHERE id == 1"; + SQLINTEGER input_char_len = static_cast(wcslen(input_str)); + SQLINTEGER output_char_len = 0; + std::wstring expected_string = std::wstring(input_str); + + ASSERT_EQ(SQL_SUCCESS, SQLNativeSql(this->conn, input_str, SQL_NTS, buf, buf_char_len, + &output_char_len)); + + EXPECT_EQ(input_char_len, output_char_len); + + // returned length is in characters + std::wstring returned_string(buf, buf + output_char_len); + + EXPECT_EQ(expected_string, returned_string); +} + +TYPED_TEST(StatementTest, TestSQLNativeSqlReturnsInputStringLength) { + SQLWCHAR input_str[] = L"SELECT * FROM mytable WHERE id == 1"; + SQLINTEGER input_char_len = static_cast(wcslen(input_str)); + SQLINTEGER output_char_len = 0; + std::wstring expected_string = std::wstring(input_str); + + ASSERT_EQ(SQL_SUCCESS, SQLNativeSql(this->conn, input_str, input_char_len, nullptr, 0, + &output_char_len)); + + EXPECT_EQ(input_char_len, output_char_len); + + ASSERT_EQ(SQL_SUCCESS, + SQLNativeSql(this->conn, input_str, SQL_NTS, nullptr, 0, &output_char_len)); + + EXPECT_EQ(input_char_len, output_char_len); +} + +TYPED_TEST(StatementTest, TestSQLNativeSqlReturnsTruncatedString) { + const SQLINTEGER small_buf_size_in_char = 11; + SQLWCHAR small_buf[small_buf_size_in_char]; + SQLINTEGER small_buf_char_len = sizeof(small_buf) / ODBC::GetSqlWCharSize(); + SQLWCHAR input_str[] = L"SELECT * FROM mytable WHERE id == 1"; + SQLINTEGER input_char_len = static_cast(wcslen(input_str)); + SQLINTEGER output_char_len = 0; + + // Create expected return string based on buf size + SQLWCHAR expected_string_buf[small_buf_size_in_char]; + wcsncpy(expected_string_buf, input_str, 10); + expected_string_buf[10] = L'\0'; + std::wstring expected_string(expected_string_buf, + expected_string_buf + small_buf_size_in_char); + + ASSERT_EQ(SQL_SUCCESS_WITH_INFO, + SQLNativeSql(this->conn, input_str, input_char_len, small_buf, + small_buf_char_len, &output_char_len)); + VerifyOdbcErrorState(SQL_HANDLE_DBC, this->conn, kErrorState01004); + + // Returned text length represents full string char length regardless of truncation + EXPECT_EQ(input_char_len, output_char_len); + + std::wstring returned_string(small_buf, small_buf + small_buf_char_len); + + EXPECT_EQ(expected_string, returned_string); +} + +TYPED_TEST(StatementTest, TestSQLNativeSqlReturnsErrorOnBadInputs) { + SQLWCHAR buf[1024]; + SQLINTEGER buf_char_len = sizeof(buf) / ODBC::GetSqlWCharSize(); + SQLWCHAR input_str[] = L"SELECT * FROM mytable WHERE id == 1"; + SQLINTEGER input_char_len = static_cast(wcslen(input_str)); + SQLINTEGER output_char_len = 0; + + ASSERT_EQ(SQL_ERROR, SQLNativeSql(this->conn, nullptr, input_char_len, buf, + buf_char_len, &output_char_len)); + VerifyOdbcErrorState(SQL_HANDLE_DBC, this->conn, kErrorStateHY009); + + ASSERT_EQ(SQL_ERROR, SQLNativeSql(this->conn, nullptr, SQL_NTS, buf, buf_char_len, + &output_char_len)); + VerifyOdbcErrorState(SQL_HANDLE_DBC, this->conn, kErrorStateHY009); + + ASSERT_EQ(SQL_ERROR, SQLNativeSql(this->conn, input_str, -100, buf, buf_char_len, + &output_char_len)); + VerifyOdbcErrorState(SQL_HANDLE_DBC, this->conn, kErrorStateHY090); +} + +} // namespace arrow::flight::sql::odbc