Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions cpp/src/arrow/flight/sql/odbc/entry_points.cc
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,18 @@ SQLRETURN SQL_API SQLGetDiagRec(SQLSMALLINT handle_type, SQLHANDLE handle,
buffer_length, text_length_ptr);
}

#if defined(__APPLE__)
// macOS ODBC Driver Manager doesn't map SQLError to SQLGetDiagRec, so we need to
// implement SQLError for macOS.
// on Windows, SQLError mapping implemented by Driver Manager is preferred.
SQLRETURN SQL_API SQLError(SQLHENV env, SQLHDBC conn, SQLHSTMT stmt, SQLWCHAR* sql_state,
SQLINTEGER* native_error_ptr, SQLWCHAR* message_text,
SQLSMALLINT buffer_length, SQLSMALLINT* text_length_ptr) {
return arrow::flight::sql::odbc::SQLError(env, conn, stmt, sql_state, native_error_ptr,
message_text, buffer_length, text_length_ptr);
}
#endif // __APPLE__

SQLRETURN SQL_API SQLGetEnvAttr(SQLHENV env, SQLINTEGER attr, SQLPOINTER value_ptr,
SQLINTEGER buffer_len, SQLINTEGER* str_len_ptr) {
return arrow::flight::sql::odbc::SQLGetEnvAttr(env, attr, value_ptr, buffer_len,
Expand Down
51 changes: 50 additions & 1 deletion cpp/src/arrow/flight/sql/odbc/odbc_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -247,6 +247,56 @@ SQLRETURN SQLFreeStmt(SQLHSTMT handle, SQLUSMALLINT option) {
return SQL_ERROR;
}

#if defined(__APPLE__)
SQLRETURN SQLError(SQLHENV env, SQLHDBC conn, SQLHSTMT stmt, SQLWCHAR* sql_state,
SQLINTEGER* native_error_ptr, SQLWCHAR* message_text,
SQLSMALLINT buffer_length, SQLSMALLINT* text_length_ptr) {
ARROW_LOG(DEBUG) << "SQLError called with env: " << env << ", conn: " << conn
<< ", stmt: " << stmt
<< ", sql_state: " << static_cast<const void*>(sql_state)
<< ", native_error_ptr: " << static_cast<const void*>(native_error_ptr)
<< ", message_text: " << static_cast<const void*>(message_text)
<< ", buffer_length: " << buffer_length
<< ", text_length_ptr: " << static_cast<const void*>(text_length_ptr);

SQLSMALLINT handle_type;
SQLHANDLE handle;

if (env) {
handle_type = SQL_HANDLE_ENV;
handle = static_cast<SQLHANDLE>(env);
} else if (conn) {
handle_type = SQL_HANDLE_DBC;
handle = static_cast<SQLHANDLE>(conn);
} else if (stmt) {
handle_type = SQL_HANDLE_STMT;
handle = static_cast<SQLHANDLE>(stmt);
} else {
return static_cast<SQLRETURN>(SQL_INVALID_HANDLE);
}

// Use the last record
SQLINTEGER diag_number;
SQLSMALLINT diag_number_length;

SQLRETURN ret = arrow::flight::sql::odbc::SQLGetDiagField(
handle_type, handle, 0, SQL_DIAG_NUMBER, &diag_number, sizeof(SQLINTEGER), 0);
if (ret != SQL_SUCCESS) {
return ret;
}

if (diag_number == 0) {
return SQL_NO_DATA;
}

SQLSMALLINT rec_number = static_cast<SQLSMALLINT>(diag_number);

return arrow::flight::sql::odbc::SQLGetDiagRec(
handle_type, handle, rec_number, sql_state, native_error_ptr, message_text,
buffer_length, text_length_ptr);
}
#endif // __APPLE__

inline bool IsValidStringFieldArgs(SQLPOINTER diag_info_ptr, SQLSMALLINT buffer_length,
SQLSMALLINT* string_length_ptr, bool is_unicode) {
const SQLSMALLINT char_size = is_unicode ? GetSqlWCharSize() : sizeof(char);
Expand Down Expand Up @@ -736,7 +786,6 @@ SQLRETURN SQLGetConnectAttr(SQLHDBC conn, SQLINTEGER attribute, SQLPOINTER value
<< ", attribute: " << attribute << ", value_ptr: " << value_ptr
<< ", buffer_length: " << buffer_length << ", string_length_ptr: "
<< static_cast<const void*>(string_length_ptr);

using ODBC::ODBCConnection;

return ODBCConnection::ExecuteWithDiagnostics(conn, SQL_ERROR, [=]() {
Expand Down
6 changes: 6 additions & 0 deletions cpp/src/arrow/flight/sql/odbc/odbc_api_internal.h
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,12 @@ namespace arrow::flight::sql::odbc {
SQLHANDLE* result);
[[nodiscard]] SQLRETURN SQLFreeHandle(SQLSMALLINT type, SQLHANDLE handle);
[[nodiscard]] SQLRETURN SQLFreeStmt(SQLHSTMT stmt, SQLUSMALLINT option);
#if defined(__APPLE__)
[[nodiscard]] SQLRETURN SQLError(SQLHENV env, SQLHDBC conn, SQLHSTMT stmt,
SQLWCHAR* sql_state, SQLINTEGER* native_error_ptr,
SQLWCHAR* message_text, SQLSMALLINT buffer_length,
SQLSMALLINT* text_length_ptr);
#endif // __APPLE__
[[nodiscard]] SQLRETURN SQLGetDiagField(SQLSMALLINT handle_type, SQLHANDLE handle,
SQLSMALLINT rec_number,
SQLSMALLINT diag_identifier,
Expand Down
100 changes: 67 additions & 33 deletions cpp/src/arrow/flight/sql/odbc/tests/errors_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,8 @@ using TestTypesHandle = ::testing::Types<FlightSQLOdbcEnvConnHandleMockTestBase,
FlightSQLOdbcEnvConnHandleRemoteTestBase>;
TYPED_TEST_SUITE(ErrorsHandleTest, TestTypesHandle);

using ODBC::SqlWcharToString;

TYPED_TEST(ErrorsHandleTest, TestSQLGetDiagFieldWForConnectFailure) {
// Invalid connect string
std::string connect_str = this->GetInvalidConnectionString();
Expand Down Expand Up @@ -90,9 +92,12 @@ TYPED_TEST(ErrorsHandleTest, TestSQLGetDiagFieldWForConnectFailure) {
SQLWCHAR message_text[kOdbcBufferSize];
SQLSMALLINT message_text_length;

EXPECT_EQ(SQL_SUCCESS,
SQLGetDiagField(SQL_HANDLE_DBC, this->conn, RECORD_1, SQL_DIAG_MESSAGE_TEXT,
message_text, kOdbcBufferSize, &message_text_length));
SQLRETURN ret =
SQLGetDiagField(SQL_HANDLE_DBC, this->conn, RECORD_1, SQL_DIAG_MESSAGE_TEXT,
message_text, kOdbcBufferSize, &message_text_length);

// dependent on the size of the message it could output SQL_SUCCESS_WITH_INFO
EXPECT_TRUE(ret == SQL_SUCCESS || ret == SQL_SUCCESS_WITH_INFO);

EXPECT_GT(message_text_length, 100);

Expand All @@ -114,10 +119,9 @@ TYPED_TEST(ErrorsHandleTest, TestSQLGetDiagFieldWForConnectFailure) {
EXPECT_EQ(
SQL_SUCCESS,
SQLGetDiagField(SQL_HANDLE_DBC, this->conn, RECORD_1, SQL_DIAG_SQLSTATE, sql_state,
sql_state_size * arrow::flight::sql::odbc::GetSqlWCharSize(),
&sql_state_length));
sql_state_size * GetSqlWCharSize(), &sql_state_length));

EXPECT_EQ(std::wstring(L"28000"), std::wstring(sql_state));
EXPECT_EQ(kErrorState28000, SqlWcharToString(sql_state));
}

TYPED_TEST(ErrorsHandleTest, DISABLED_TestSQLGetDiagFieldWForConnectFailureNTS) {
Expand Down Expand Up @@ -156,6 +160,8 @@ TYPED_TEST(ErrorsHandleTest, DISABLED_TestSQLGetDiagFieldWForConnectFailureNTS)
EXPECT_GT(message_text_length, 100);
}

// iODBC does not support application allocated descriptors.
#ifndef __APPLE__
TYPED_TEST(ErrorsTest, TestSQLGetDiagFieldWForDescriptorFailureFromDriverManager) {
SQLHDESC descriptor;

Expand Down Expand Up @@ -216,7 +222,7 @@ TYPED_TEST(ErrorsTest, TestSQLGetDiagFieldWForDescriptorFailureFromDriverManager
SQLGetDiagField(SQL_HANDLE_DESC, descriptor, RECORD_1, SQL_DIAG_SQLSTATE, sql_state,
sql_state_size * GetSqlWCharSize(), &sql_state_length));

EXPECT_EQ(std::wstring(L"IM001"), std::wstring(sql_state));
EXPECT_EQ(kErrorStateIM001, SqlWcharToString(sql_state));

// Free descriptor handle
EXPECT_EQ(SQL_SUCCESS, SQLFreeHandle(SQL_HANDLE_DESC, descriptor));
Expand Down Expand Up @@ -245,13 +251,14 @@ TYPED_TEST(ErrorsTest, TestSQLGetDiagRecForDescriptorFailureFromDriverManager) {
EXPECT_EQ(0, native_error);

// API not implemented error from driver manager
EXPECT_EQ(std::wstring(L"IM001"), std::wstring(sql_state));
EXPECT_EQ(kErrorStateIM001, SqlWcharToString(sql_state));

EXPECT_FALSE(std::wstring(message).empty());

// Free descriptor handle
EXPECT_EQ(SQL_SUCCESS, SQLFreeHandle(SQL_HANDLE_DESC, descriptor));
}
#endif // __APPLE__

TYPED_TEST(ErrorsHandleTest, TestSQLGetDiagRecForConnectFailure) {
// Invalid connect string
Expand Down Expand Up @@ -282,7 +289,7 @@ TYPED_TEST(ErrorsHandleTest, TestSQLGetDiagRecForConnectFailure) {

EXPECT_EQ(200, native_error);

EXPECT_EQ(std::wstring(L"28000"), std::wstring(sql_state));
EXPECT_EQ(kErrorState28000, SqlWcharToString(sql_state));

EXPECT_FALSE(std::wstring(message).empty());
}
Expand All @@ -305,11 +312,17 @@ TYPED_TEST(ErrorsTest, TestSQLGetDiagRecInputData) {
nullptr, 0, nullptr));

// Invalid handle
#ifdef __APPLE__
// MacOS ODBC driver manager requires connection handle
EXPECT_EQ(SQL_INVALID_HANDLE,
SQLGetDiagRec(0, this->conn, 1, nullptr, nullptr, nullptr, 0, nullptr));
#else
EXPECT_EQ(SQL_INVALID_HANDLE,
SQLGetDiagRec(0, nullptr, 0, nullptr, nullptr, nullptr, 0, nullptr));
#endif // __APPLE__
}

TYPED_TEST(ErrorsTest, TestSQLErrorInputData) {
TYPED_TEST(ErrorsOdbcV2Test, TestSQLErrorInputData) {
// Test ODBC 2.0 API SQLError. Driver manager maps SQLError to SQLGetDiagRec.
// SQLError does not post diagnostic records for itself.

Expand All @@ -320,8 +333,13 @@ TYPED_TEST(ErrorsTest, TestSQLErrorInputData) {
EXPECT_EQ(SQL_NO_DATA, SQLError(nullptr, this->conn, nullptr, nullptr, nullptr, nullptr,
0, nullptr));

#ifdef __APPLE__
EXPECT_EQ(SQL_NO_DATA, SQLError(SQL_NULL_HENV, this->conn, this->stmt, nullptr, nullptr,
nullptr, 0, nullptr));
#else
EXPECT_EQ(SQL_NO_DATA, SQLError(nullptr, nullptr, this->stmt, nullptr, nullptr, nullptr,
0, nullptr));
#endif // __APPLE__

// Invalid handle
EXPECT_EQ(SQL_INVALID_HANDLE,
Expand All @@ -345,12 +363,12 @@ TYPED_TEST(ErrorsTest, TestSQLErrorEnvErrorFromDriverManager) {
ASSERT_EQ(SQL_SUCCESS, SQLError(this->env, nullptr, nullptr, sql_state, &native_error,
message, SQL_MAX_MESSAGE_LENGTH, &message_length));

EXPECT_GT(message_length, 50);
EXPECT_GT(message_length, 40);

EXPECT_EQ(0, native_error);

// Function sequence error state from driver manager
EXPECT_EQ(std::wstring(L"HY010"), std::wstring(sql_state));
EXPECT_EQ(kErrorStateHY010, SqlWcharToString(sql_state));

EXPECT_FALSE(std::wstring(message).empty());
}
Expand All @@ -362,9 +380,8 @@ TYPED_TEST(ErrorsTest, TestSQLErrorConnError) {
// DM passes 512 as buffer length to SQLError.

// Attempt to set unsupported attribute
SQLRETURN ret = SQLGetConnectAttr(this->conn, SQL_ATTR_TXN_ISOLATION, 0, 0, nullptr);

ASSERT_EQ(SQL_ERROR, ret);
ASSERT_EQ(SQL_ERROR,
SQLGetConnectAttr(this->conn, SQL_ATTR_TXN_ISOLATION, 0, 0, nullptr));

SQLWCHAR sql_state[6] = {0};
SQLINTEGER native_error = 0;
Expand All @@ -378,7 +395,7 @@ TYPED_TEST(ErrorsTest, TestSQLErrorConnError) {
EXPECT_EQ(100, native_error);

// optional feature not supported error state
EXPECT_EQ(std::wstring(L"HYC00"), std::wstring(sql_state));
EXPECT_EQ(kErrorStateHYC00, SqlWcharToString(sql_state));

EXPECT_FALSE(std::wstring(message).empty());
}
Expand All @@ -399,14 +416,16 @@ TYPED_TEST(ErrorsTest, TestSQLErrorStmtError) {
SQLINTEGER native_error = 0;
SQLWCHAR message[SQL_MAX_MESSAGE_LENGTH] = {0};
SQLSMALLINT message_length = 0;
ASSERT_EQ(SQL_SUCCESS, SQLError(nullptr, nullptr, this->stmt, sql_state, &native_error,
message, SQL_MAX_MESSAGE_LENGTH, &message_length));
SQLRETURN ret = SQLError(nullptr, this->conn, this->stmt, sql_state, &native_error,
message, SQL_MAX_MESSAGE_LENGTH, &message_length);

EXPECT_TRUE(ret == SQL_SUCCESS || ret == SQL_SUCCESS_WITH_INFO);

EXPECT_GT(message_length, 70);

EXPECT_EQ(100, native_error);

EXPECT_EQ(std::wstring(L"HY000"), std::wstring(sql_state));
EXPECT_EQ(kErrorStateHY000, SqlWcharToString(sql_state));

EXPECT_FALSE(std::wstring(message).empty());
}
Expand Down Expand Up @@ -434,20 +453,21 @@ TYPED_TEST(ErrorsTest, TestSQLErrorStmtWarning) {
SQLINTEGER native_error = 0;
SQLWCHAR message[SQL_MAX_MESSAGE_LENGTH] = {0};
SQLSMALLINT message_length = 0;
ASSERT_EQ(SQL_SUCCESS, SQLError(nullptr, nullptr, this->stmt, sql_state, &native_error,
message, SQL_MAX_MESSAGE_LENGTH, &message_length));
ASSERT_EQ(SQL_SUCCESS,
SQLError(SQL_NULL_HENV, this->conn, this->stmt, sql_state, &native_error,
message, SQL_MAX_MESSAGE_LENGTH, &message_length));

EXPECT_GT(message_length, 50);

EXPECT_EQ(1000100, native_error);

// Verify string truncation warning is reported
EXPECT_EQ(std::wstring(L"01004"), std::wstring(sql_state));
EXPECT_EQ(kErrorState01004, SqlWcharToString(sql_state));

EXPECT_FALSE(std::wstring(message).empty());
}

TYPED_TEST(ErrorsOdbcV2Test, TestSQLErrorEnvErrorODBCVer2FromDriverManager) {
TYPED_TEST(ErrorsOdbcV2Test, TestSQLErrorEnvErrorFromDriverManager) {
// Test ODBC 2.0 API SQLError with ODBC ver 2.
// Known Windows Driver Manager (DM) behavior:
// When application passes buffer length greater than SQL_MAX_MESSAGE_LENGTH (512),
Expand All @@ -464,22 +484,34 @@ TYPED_TEST(ErrorsOdbcV2Test, TestSQLErrorEnvErrorODBCVer2FromDriverManager) {
ASSERT_EQ(SQL_SUCCESS, SQLError(this->env, nullptr, nullptr, sql_state, &native_error,
message, SQL_MAX_MESSAGE_LENGTH, &message_length));

EXPECT_GT(message_length, 50);
EXPECT_GT(message_length, 40);

EXPECT_EQ(0, native_error);

// Function sequence error state from driver manager
EXPECT_EQ(std::wstring(L"S1010"), std::wstring(sql_state));
#ifdef _WIN32
// Windows Driver Manager returns S1010
EXPECT_EQ(kErrorStateS1010, SqlWcharToString(sql_state));
#else
// unix Driver Manager returns HY010
EXPECT_EQ(kErrorStateHY010, SqlWcharToString(sql_state));
#endif // _WIN32

EXPECT_FALSE(std::wstring(message).empty());
}

TYPED_TEST(ErrorsOdbcV2Test, TestSQLErrorConnErrorODBCVer2) {
#ifndef __APPLE__
TYPED_TEST(ErrorsOdbcV2Test, TestSQLErrorConnError) {
// Test ODBC 2.0 API SQLError with ODBC ver 2.
// Known Windows Driver Manager (DM) behavior:
// When application passes buffer length greater than SQL_MAX_MESSAGE_LENGTH (512),
// DM passes 512 as buffer length to SQLError.

// Known macOS Driver Manager (DM) behavior:
// Attempts to call SQLGetConnectOption without redirecting the API call to
// SQLGetConnectAttr. SQLGetConnectOption is not implemented as it is not required by
// macOS Excel.

// Attempt to set unsupported attribute
ASSERT_EQ(SQL_ERROR,
SQLGetConnectAttr(this->conn, SQL_ATTR_TXN_ISOLATION, 0, 0, nullptr));
Expand All @@ -496,12 +528,13 @@ TYPED_TEST(ErrorsOdbcV2Test, TestSQLErrorConnErrorODBCVer2) {
EXPECT_EQ(100, native_error);

// optional feature not supported error state. Driver Manager maps state to S1C00
EXPECT_EQ(std::wstring(L"S1C00"), std::wstring(sql_state));
EXPECT_EQ(kErrorStateS1C00, SqlWcharToString(sql_state));

EXPECT_FALSE(std::wstring(message).empty());
}
#endif // __APPLE__

TYPED_TEST(ErrorsOdbcV2Test, TestSQLErrorStmtErrorODBCVer2) {
TYPED_TEST(ErrorsOdbcV2Test, TestSQLErrorStmtError) {
// Test ODBC 2.0 API SQLError with ODBC ver 2.
// Known Windows Driver Manager (DM) behavior:
// When application passes buffer length greater than SQL_MAX_MESSAGE_LENGTH (512),
Expand All @@ -525,12 +558,12 @@ TYPED_TEST(ErrorsOdbcV2Test, TestSQLErrorStmtErrorODBCVer2) {
EXPECT_EQ(100, native_error);

// Driver Manager maps error state to S1000
EXPECT_EQ(std::wstring(L"S1000"), std::wstring(sql_state));
EXPECT_EQ(kErrorStateS1000, SqlWcharToString(sql_state));

EXPECT_FALSE(std::wstring(message).empty());
}

TYPED_TEST(ErrorsOdbcV2Test, TestSQLErrorStmtWarningODBCVer2) {
TYPED_TEST(ErrorsOdbcV2Test, TestSQLErrorStmtWarning) {
// Test ODBC 2.0 API SQLError.

std::wstring wsql = L"SELECT 'VERY LONG STRING here' AS string_col;";
Expand All @@ -553,15 +586,16 @@ TYPED_TEST(ErrorsOdbcV2Test, TestSQLErrorStmtWarningODBCVer2) {
SQLINTEGER native_error = 0;
SQLWCHAR message[SQL_MAX_MESSAGE_LENGTH] = {0};
SQLSMALLINT message_length = 0;
ASSERT_EQ(SQL_SUCCESS, SQLError(nullptr, nullptr, this->stmt, sql_state, &native_error,
message, SQL_MAX_MESSAGE_LENGTH, &message_length));
ASSERT_EQ(SQL_SUCCESS,
SQLError(SQL_NULL_HENV, this->conn, this->stmt, sql_state, &native_error,
message, SQL_MAX_MESSAGE_LENGTH, &message_length));

EXPECT_GT(message_length, 50);

EXPECT_EQ(1000100, native_error);

// Verify string truncation warning is reported
EXPECT_EQ(std::wstring(L"01004"), std::wstring(sql_state));
EXPECT_EQ(kErrorState01004, SqlWcharToString(sql_state));

EXPECT_FALSE(std::wstring(message).empty());
}
Expand Down
2 changes: 2 additions & 0 deletions cpp/src/arrow/flight/sql/odbc/tests/odbc_test_suite.h
Original file line number Diff line number Diff line change
Expand Up @@ -237,10 +237,12 @@ static constexpr std::string_view kErrorStateHY114 = "HY114";
static constexpr std::string_view kErrorStateHY118 = "HY118";
static constexpr std::string_view kErrorStateHYC00 = "HYC00";
static constexpr std::string_view kErrorStateIM001 = "IM001";
static constexpr std::string_view kErrorStateS1000 = "S1000";
static constexpr std::string_view kErrorStateS1002 = "S1002";
static constexpr std::string_view kErrorStateS1004 = "S1004";
static constexpr std::string_view kErrorStateS1010 = "S1010";
static constexpr std::string_view kErrorStateS1090 = "S1090";
static constexpr std::string_view kErrorStateS1C00 = "S1C00";

/// Verify ODBC Error State
void VerifyOdbcErrorState(SQLSMALLINT handle_type, SQLHANDLE handle,
Expand Down
Loading