Skip to content

Commit

Permalink
feat(go/adbc/driver/snowflake): add query tag option (#2484)
Browse files Browse the repository at this point in the history
This lets you identify particular queries in the query history.

Fixes #1934.
  • Loading branch information
lidavidm authored Jan 28, 2025
1 parent c4d2dab commit f37fd5c
Show file tree
Hide file tree
Showing 3 changed files with 67 additions and 0 deletions.
41 changes: 41 additions & 0 deletions go/adbc/driver/snowflake/driver_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2253,3 +2253,44 @@ func TestJSONUnmarshal(t *testing.T) {
}
}
}

func (suite *SnowflakeTests) TestQueryTag() {
u, err := uuid.NewV7()
suite.Require().NoError(err)
tag := u.String()
suite.Require().NoError(suite.stmt.SetOption(driver.OptionStatementQueryTag, tag))

val, err := suite.stmt.(adbc.GetSetOptions).GetOption(driver.OptionStatementQueryTag)
suite.Require().NoError(err)
suite.Require().Equal(tag, val)

suite.Require().NoError(suite.stmt.SetSqlQuery("SELECT 1"))
rdr, n, err := suite.stmt.ExecuteQuery(suite.ctx)
suite.Require().NoError(err)
defer rdr.Release()

suite.EqualValues(1, n)
suite.True(rdr.Next())
suite.False(rdr.Next())
suite.Require().NoError(rdr.Err())

// Unset tag
suite.Require().NoError(suite.stmt.SetOption(driver.OptionStatementQueryTag, ""))

suite.Require().NoError(suite.stmt.SetSqlQuery(fmt.Sprintf(`
SELECT query_text
FROM table(information_schema.query_history())
WHERE query_tag = '%s'
ORDER BY start_time;
`, tag)))
rdr, n, err = suite.stmt.ExecuteQuery(suite.ctx)
suite.Require().NoError(err)
defer rdr.Release()

suite.EqualValues(1, n)
suite.True(rdr.Next())
result := rdr.Record()
suite.Require().Equal("SELECT 1", result.Column(0).(*array.String).Value(0))
suite.False(rdr.Next())
suite.Require().NoError(rdr.Err())
}
23 changes: 23 additions & 0 deletions go/adbc/driver/snowflake/statement.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ import (
)

const (
OptionStatementQueryTag = "adbc.snowflake.statement.query_tag"
OptionStatementQueueSize = "adbc.rpc.result_queue_size"
OptionStatementPrefetchConcurrency = "adbc.snowflake.rpc.prefetch_concurrency"
OptionStatementIngestWriterConcurrency = "adbc.snowflake.statement.ingest_writer_concurrency"
Expand All @@ -54,11 +55,20 @@ type statement struct {
targetTable string
ingestMode string
ingestOptions *ingestOptions
queryTag string

bound arrow.Record
streamBind array.RecordReader
}

// setQueryContext applies the query tag if present.
func (st *statement) setQueryContext(ctx context.Context) context.Context {
if st.queryTag != "" {
ctx = gosnowflake.WithQueryTag(ctx, st.queryTag)
}
return ctx
}

// Close releases any relevant resources associated with this statement
// and closes it (particularly if it is a prepared statement).
//
Expand All @@ -82,6 +92,10 @@ func (st *statement) Close() error {
}

func (st *statement) GetOption(key string) (string, error) {
switch key {
case OptionStatementQueryTag:
return st.queryTag, nil
}
return "", adbc.Error{
Msg: fmt.Sprintf("[Snowflake] Unknown statement option '%s'", key),
Code: adbc.StatusNotFound,
Expand Down Expand Up @@ -186,6 +200,9 @@ func (st *statement) SetOption(key string, val string) error {
}
}
return st.SetOptionInt(key, int64(size))
case OptionStatementQueryTag:
st.queryTag = val
return nil
case OptionUseHighPrecision:
switch val {
case adbc.OptionValueEnabled:
Expand Down Expand Up @@ -449,6 +466,8 @@ func (st *statement) executeIngest(ctx context.Context) (int64, error) {
//
// This invalidates any prior result sets on this statement.
func (st *statement) ExecuteQuery(ctx context.Context) (array.RecordReader, int64, error) {
ctx = st.setQueryContext(ctx)

if st.targetTable != "" {
n, err := st.executeIngest(ctx)
return nil, n, err
Expand Down Expand Up @@ -500,6 +519,8 @@ func (st *statement) ExecuteQuery(ctx context.Context) (array.RecordReader, int6
// ExecuteUpdate executes a statement that does not generate a result
// set. It returns the number of rows affected if known, otherwise -1.
func (st *statement) ExecuteUpdate(ctx context.Context) (int64, error) {
ctx = st.setQueryContext(ctx)

if st.targetTable != "" {
return st.executeIngest(ctx)
}
Expand Down Expand Up @@ -558,6 +579,8 @@ func (st *statement) ExecuteUpdate(ctx context.Context) (int64, error) {

// ExecuteSchema gets the schema of the result set of a query without executing it.
func (st *statement) ExecuteSchema(ctx context.Context) (*arrow.Schema, error) {
ctx = st.setQueryContext(ctx)

if st.targetTable != "" {
return nil, adbc.Error{
Msg: "cannot execute schema for ingestion",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,9 @@ class StatementOptions(enum.Enum):
#: Number of concurrent streams being prefetched for a result set.
#: Defaults to 10.
PREFETCH_CONCURRENCY = "adbc.snowflake.rpc.prefetch_concurrency"
#: An identifier for a query/queries that can be used to find the query in
#: the query history. Use a blank string to unset the tag.
QUERY_TAG = "adbc.snowflake.statement.query_tag"
#: Number of parquet files to write in parallel for bulk ingestion
#: Defaults to NumCPU
INGEST_WRITER_CONCURRENCY = "adbc.snowflake.statement.ingest_writer_concurrency"
Expand Down

0 comments on commit f37fd5c

Please sign in to comment.