Skip to content

Commit

Permalink
fix(go/adbc/driver/snowflake): Removing SQL injection to get table na…
Browse files Browse the repository at this point in the history
…me with special character for getObjectsTables (#1338)

**Description:**

> GetObjects API was inconsistent getting table with special character
and making the conditions case-insensitive.

**Solution:**

> Passing table names as query argument and avoiding SQL Injection

**Testing:**

> Added test in DriverTest

Fixes #1225

---------

Co-authored-by: Anitha <[email protected]>
Co-authored-by: David Li <[email protected]>
  • Loading branch information
3 people authored Jan 3, 2024
1 parent a3c0b1d commit 64c19eb
Show file tree
Hide file tree
Showing 5 changed files with 756 additions and 202 deletions.
105 changes: 98 additions & 7 deletions csharp/test/Drivers/Interop/Snowflake/DriverTests.cs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
/*
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
Expand All @@ -16,6 +16,7 @@
*/

using System;
using System.Collections;
using System.Collections.Generic;
using System.Linq;
using Apache.Arrow.Adbc.Tests.Metadata;
Expand Down Expand Up @@ -110,10 +111,7 @@ public void CanExecuteUpdate()
for (int i = 0; i < queries.Length; i++)
{
string query = queries[i];
using AdbcStatement statement = _connection.CreateStatement();
statement.SqlQuery = query;

UpdateResult updateResult = statement.ExecuteUpdate();
UpdateResult updateResult = ExecuteUpdateStatement(query);

Assert.Equal(expectedResults[i], updateResult.AffectedRows);
}
Expand Down Expand Up @@ -279,17 +277,66 @@ public void CanGetObjectsAll()
{
IEnumerable<AdbcColumn> highPrecisionColumns = columns.Where(c => c.XdbcTypeName == "NUMBER");

if(highPrecisionColumns.Count() > 0)
if (highPrecisionColumns.Count() > 0)
{
// ensure they all are coming back as XdbcDataType_XDBC_DECIMAL because they are Decimal128
short XdbcDataType_XDBC_DECIMAL = 3;
IEnumerable<AdbcColumn> invalidHighPrecisionColumns = highPrecisionColumns.Where(c => c.XdbcSqlDataType != XdbcDataType_XDBC_DECIMAL);
IEnumerable<AdbcColumn> invalidHighPrecisionColumns = highPrecisionColumns.Where(c => c.XdbcSqlDataType != XdbcDataType_XDBC_DECIMAL);
int count = invalidHighPrecisionColumns.Count();
Assert.True(count == 0, $"There are {count} columns that do not map to the correct XdbcSqlDataType when UseHighPrecision=true");
}
}
}

/// <summary>
/// Validates if the driver can call GetObjects with GetObjectsDepth as Tables with TableName as a Special Character.
/// </summary>
[SkippableTheory, Order(3)]
[InlineData(@"ADBCDEMO_DB",@"PUBLIC","MyIdentifier")]
[InlineData(@"ADBCDEMO'DB", @"PUBLIC'SCHEMA","my.identifier")]
[InlineData(@"ADBCDEM""DB", @"PUBLIC""SCHEMA", "my.identifier")]
[InlineData(@"ADBCDEMO_DB", @"PUBLIC", "my identifier")]
[InlineData(@"ADBCDEMO_DB", @"PUBLIC", "My 'Identifier'")]
[InlineData(@"ADBCDEMO_DB", @"PUBLIC", "3rd_identifier")]
[InlineData(@"ADBCDEMO_DB", @"PUBLIC", "$Identifier")]
[InlineData(@"ADBCDEMO_DB", @"PUBLIC", "My ^Identifier")]
[InlineData(@"ADBCDEMO_DB", @"PUBLIC", "My ^Ident~ifier")]
[InlineData(@"ADBCDEMO_DB", @"PUBLIC", @"My\^Ident~ifier")]
[InlineData(@"ADBCDEMO_DB", @"PUBLIC", "идентификатор")]
[InlineData(@"ADBCDEMO_DB", @"PUBLIC", @"ADBCTest_""ALL""TYPES")]
[InlineData(@"ADBCDEMO_DB", @"PUBLIC", @"ADBC\TEST""\TAB_""LE")]
[InlineData(@"ADBCDEMO_DB", @"PUBLIC", "ONE")]
public void CanGetObjectsTablesWithSpecialCharacter(string databaseName, string schemaName, string tableName)
{
CreateDatabaseAndTable(databaseName, schemaName, tableName);

using IArrowArrayStream stream = _connection.GetObjects(
depth: AdbcConnection.GetObjectsDepth.Tables,
catalogPattern: databaseName,
dbSchemaPattern: schemaName,
tableNamePattern: tableName,
tableTypes: new List<string> { "BASE TABLE", "VIEW" },
columnNamePattern: null);

using RecordBatch recordBatch = stream.ReadNextRecordBatchAsync().Result;

List<AdbcCatalog> catalogs = GetObjectsParser.ParseCatalog(recordBatch, databaseName, schemaName);

List<AdbcTable> tables = catalogs
.Where(c => string.Equals(c.Name, databaseName))
.Select(c => c.DbSchemas)
.FirstOrDefault()
.Where(s => string.Equals(s.Name, schemaName))
.Select(s => s.Tables)
.FirstOrDefault();

AdbcTable table = tables.FirstOrDefault();

Assert.True(table != null, "table should not be null");
Assert.Equal(tableName, table.Name, true);
DropDatabaseAndTable(databaseName, schemaName, tableName);
}

/// <summary>
/// Validates if the driver can call GetTableSchema.
/// </summary>
Expand Down Expand Up @@ -354,6 +401,50 @@ public void CanExecuteQuery()
Tests.DriverTests.CanExecuteQuery(queryResult, _testConfiguration.ExpectedResultsCount);
}

private void CreateDatabaseAndTable(string databaseName, string schemaName, string tableName)
{
databaseName = databaseName.Replace("\"", "\"\"");
schemaName = schemaName.Replace("\"", "\"\"");
tableName = tableName.Replace("\"", "\"\"");

string createDatabase = string.Format("CREATE DATABASE IF NOT EXISTS \"{0}\"", databaseName);
ExecuteUpdateStatement(createDatabase);

string createSchema = string.Format("CREATE SCHEMA IF NOT EXISTS \"{0}\".\"{1}\"", databaseName, schemaName);
ExecuteUpdateStatement(createSchema);

string fullyQualifiedTableName = string.Format("\"{0}\".\"{1}\".\"{2}\"", databaseName, schemaName, tableName);
string createTableStatement = string.Format("CREATE OR REPLACE TABLE {0} (INDEX INT)", fullyQualifiedTableName);
ExecuteUpdateStatement(createTableStatement);

}

private void DropDatabaseAndTable(string databaseName, string schemaName, string tableName)
{
tableName = tableName.Replace("\"", "\"\"");
schemaName = schemaName.Replace("\"", "\"\"");
databaseName = databaseName.Replace("\"", "\"\"");

string fullyQualifiedTableName = string.Format("\"{0}\".\"{1}\".\"{2}\"", databaseName, schemaName, tableName);
string createTableStatement = string.Format("DROP TABLE IF EXISTS {0} ", fullyQualifiedTableName);
ExecuteUpdateStatement(createTableStatement);

string createSchema = string.Format("DROP SCHEMA IF EXISTS \"{0}\".\"{1}\"", databaseName, schemaName);
ExecuteUpdateStatement(createSchema);

string createDatabase = string.Format("DROP DATABASE IF EXISTS \"{0}\"", databaseName);
ExecuteUpdateStatement(createDatabase);

}

private UpdateResult ExecuteUpdateStatement(string query)
{
using AdbcStatement statement = _connection.CreateStatement();
statement.SqlQuery = query;
UpdateResult updateResult = statement.ExecuteUpdate();
return updateResult;
}

private static string GetPartialNameForPatternMatch(string name)
{
if (string.IsNullOrEmpty(name) || name.Length == 1) return name;
Expand Down
4 changes: 2 additions & 2 deletions go/adbc/driver/flightsql/flightsql_connection.go
Original file line number Diff line number Diff line change
Expand Up @@ -547,7 +547,7 @@ func (c *cnxn) readInfo(ctx context.Context, expectedSchema *arrow.Schema, info
}

// Helper function to build up a map of catalogs to DB schemas
func (c *cnxn) getObjectsDbSchemas(ctx context.Context, depth adbc.ObjectDepth, catalog *string, dbSchema *string) (result map[string][]string, err error) {
func (c *cnxn) getObjectsDbSchemas(ctx context.Context, depth adbc.ObjectDepth, catalog *string, dbSchema *string, metadataRecords []internal.Metadata) (result map[string][]string, err error) {
if depth == adbc.ObjectDepthCatalogs {
return
}
Expand Down Expand Up @@ -588,7 +588,7 @@ func (c *cnxn) getObjectsDbSchemas(ctx context.Context, depth adbc.ObjectDepth,
return
}

func (c *cnxn) getObjectsTables(ctx context.Context, depth adbc.ObjectDepth, catalog *string, dbSchema *string, tableName *string, columnName *string, tableType []string) (result internal.SchemaToTableInfo, err error) {
func (c *cnxn) getObjectsTables(ctx context.Context, depth adbc.ObjectDepth, catalog *string, dbSchema *string, tableName *string, columnName *string, tableType []string, metadataRecords []internal.Metadata) (result internal.SchemaToTableInfo, err error) {
if depth == adbc.ObjectDepthCatalogs || depth == adbc.ObjectDepthDBSchemas {
return
}
Expand Down
21 changes: 17 additions & 4 deletions go/adbc/driver/internal/shared_utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,11 @@ package internal

import (
"context"
"database/sql"
"regexp"
"strconv"
"strings"
"time"

"github.com/apache/arrow-adbc/go/adbc"
"github.com/apache/arrow/go/v14/arrow"
Expand All @@ -38,8 +40,18 @@ type TableInfo struct {
Schema *arrow.Schema
}

type GetObjDBSchemasFn func(ctx context.Context, depth adbc.ObjectDepth, catalog *string, schema *string) (map[string][]string, error)
type GetObjTablesFn func(ctx context.Context, depth adbc.ObjectDepth, catalog *string, schema *string, tableName *string, columnName *string, tableType []string) (map[CatalogAndSchema][]TableInfo, error)
type Metadata struct {
Created time.Time
ColName, DataType string
Dbname, Kind, Schema, TblName, TblType, IdentGen, IdentIncrement, Comment sql.NullString
OrdinalPos int
NumericPrec, NumericPrecRadix, NumericScale, DatetimePrec sql.NullInt16
IsNullable, IsIdent bool
CharMaxLength, CharOctetLength sql.NullInt32
}

type GetObjDBSchemasFn func(ctx context.Context, depth adbc.ObjectDepth, catalog *string, schema *string, metadataRecords []Metadata) (map[string][]string, error)
type GetObjTablesFn func(ctx context.Context, depth adbc.ObjectDepth, catalog *string, schema *string, tableName *string, columnName *string, tableType []string, metadataRecords []Metadata) (map[CatalogAndSchema][]TableInfo, error)
type SchemaToTableInfo = map[CatalogAndSchema][]TableInfo

// Helper function that compiles a SQL-style pattern (%, _) to a regex
Expand Down Expand Up @@ -87,6 +99,7 @@ type GetObjects struct {
builder *array.RecordBuilder
schemaLookup map[string][]string
tableLookup map[CatalogAndSchema][]TableInfo
MetadataRecords []Metadata
catalogPattern *regexp.Regexp
columnNamePattern *regexp.Regexp

Expand Down Expand Up @@ -123,13 +136,13 @@ type GetObjects struct {
}

func (g *GetObjects) Init(mem memory.Allocator, getObj GetObjDBSchemasFn, getTbls GetObjTablesFn) error {
if catalogToDbSchemas, err := getObj(g.Ctx, g.Depth, g.Catalog, g.DbSchema); err != nil {
if catalogToDbSchemas, err := getObj(g.Ctx, g.Depth, g.Catalog, g.DbSchema, g.MetadataRecords); err != nil {
return err
} else {
g.schemaLookup = catalogToDbSchemas
}

if tableLookup, err := getTbls(g.Ctx, g.Depth, g.Catalog, g.DbSchema, g.TableName, g.ColumnName, g.TableType); err != nil {
if tableLookup, err := getTbls(g.Ctx, g.Depth, g.Catalog, g.DbSchema, g.TableName, g.ColumnName, g.TableType, g.MetadataRecords); err != nil {
return err
} else {
g.tableLookup = tableLookup
Expand Down
Loading

0 comments on commit 64c19eb

Please sign in to comment.