Skip to content

Commit

Permalink
fix(sql): allow parameterized query through sql sanitization (#1576)
Browse files Browse the repository at this point in the history
  • Loading branch information
ArslanSaleem authored Jan 31, 2025
1 parent 0c6738b commit f667367
Show file tree
Hide file tree
Showing 5 changed files with 61 additions and 4 deletions.
19 changes: 16 additions & 3 deletions extensions/connectors/sql/pandasai_sql/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import warnings
from typing import Optional

import pandas as pd
Expand All @@ -17,7 +18,11 @@ def load_from_mysql(
database=connection_info.database,
port=connection_info.port,
)
return pd.read_sql(query, conn, params=params)
# Suppress SQLAlchemy-related warnings
# TODO: this can be removed later, once SQLAlchemy is used
with warnings.catch_warnings():
warnings.simplefilter("ignore", category=UserWarning)
return pd.read_sql(query, conn, params=params)


def load_from_postgres(
Expand All @@ -32,7 +37,11 @@ def load_from_postgres(
dbname=connection_info.database,
port=connection_info.port,
)
return pd.read_sql(query, conn, params=params)
# Suppress SQLAlchemy-related warnings
# TODO: this can be removed later, once SQLAlchemy is used
with warnings.catch_warnings():
warnings.simplefilter("ignore", category=UserWarning)
return pd.read_sql(query, conn, params=params)


def load_from_cockroachdb(
Expand All @@ -47,7 +56,11 @@ def load_from_cockroachdb(
dbname=connection_info.database,
port=connection_info.port,
)
return pd.read_sql(query, conn, params=params)
# Suppress SQLAlchemy-related warnings
# TODO: this can be removed later, once SQLAlchemy is used
with warnings.catch_warnings():
warnings.simplefilter("ignore", category=UserWarning)
return pd.read_sql(query, conn, params=params)


__all__ = [
Expand Down
6 changes: 6 additions & 0 deletions pandasai/data_loader/sql_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,12 @@ def execute_query(self, query: str, params: Optional[list] = None) -> pd.DataFra
connection_info, formatted_query, params
)
return self._apply_transformations(dataframe)

except ModuleNotFoundError as e:
raise ImportError(
f"{source_type.capitalize()} connector not found. Please install the pandasai_sql[{source_type}] library, e.g. `pip install pandasai_sql[{source_type}]`."
) from e

except Exception as e:
raise RuntimeError(
f"Failed to execute query for '{source_type}' with: {formatted_query}"
Expand Down
8 changes: 7 additions & 1 deletion pandasai/helpers/sql_sanitizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,8 +58,14 @@ def is_sql_query_safe(query: str) -> bool:
r"--",
r"/\*.*\*/", # Block comments and inline comments
]

placeholder = "___PLACEHOLDER___" # Temporary placeholder for params

# Replace '%s' (MySQL, Psycopg2) with a unique placeholder
temp_query = query.replace("%s", placeholder)

# Parse the query to extract its structure
parsed = sqlglot.parse_one(query)
parsed = sqlglot.parse_one(temp_query)

# Ensure the main query is SELECT
if parsed.key.upper() != "SELECT":
Expand Down
28 changes: 28 additions & 0 deletions tests/unit_tests/data_loader/test_sql_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,3 +197,31 @@ def test_mysql_safe_query(self, mysql_schema):

assert isinstance(result, DataFrame)
mock_sql_query.assert_called_once_with("select * from users")

def test_mysql_malicious_with_no_import(self, mysql_schema):
    """A missing connector module must surface as ImportError.

    The loader function returned by ``_get_loader_function`` is replaced
    with a mock that raises ``ModuleNotFoundError``; ``execute_query`` is
    expected to re-raise that as ``ImportError``. The SQL safety check is
    patched to return ``True`` so the query reaches the loader call.
    """
    with patch(
        "pandasai.data_loader.sql_loader.is_sql_query_safe"
    ) as mock_sql_query, patch(
        "pandasai.data_loader.sql_loader.SQLDatasetLoader._get_loader_function"
    ) as mock_loader_function:
        # Let the query pass sanitization so the failure comes from the
        # connector call, not from the safety check.
        mock_sql_query.return_value = True

        # Simulate the connector package not being installed.
        mock_loader_function.return_value = MagicMock(
            side_effect=ModuleNotFoundError("Error")
        )

        loader = SQLDatasetLoader(mysql_schema, "test/users")

        with pytest.raises(ImportError):
            loader.execute_query("select * from users")
4 changes: 4 additions & 0 deletions tests/unit_tests/helpers/test_sql_sanitizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,10 @@ def test_safe_query_with_subquery(self):
query
) # Safe query with subquery, no dangerous keyword

def test_safe_query_with_query_params(self):
    """A SELECT containing ``%s`` parameter placeholders is reported as safe."""
    parameterized_sql = (
        "SELECT * FROM (SELECT * FROM heart_data) AS filtered_data LIMIT %s OFFSET %s"
    )
    assert is_sql_query_safe(parameterized_sql)


if __name__ == "__main__":
unittest.main()

0 comments on commit f667367

Please sign in to comment.