feat[DatabricksConnector]: connector to connect to Databricks on Cloud #580

Merged
merged 5 commits into from
Sep 21, 2023
27 changes: 27 additions & 0 deletions examples/from_databricks.py
@@ -0,0 +1,27 @@
"""Example of using PandasAI with Databricks"""

from pandasai import SmartDataframe
from pandasai.llm import OpenAI
from pandasai.connectors import DatabricksConnector

databricks_connector = DatabricksConnector(
    config={
        "host": "adb-*****.azuredatabricks.net",
        "database": "default",
        "token": "dapidfd412321",
        "port": 443,
        "table": "loan_payments_data",
        "httpPath": "/sql/1.0/warehouses/213421312",
        "where": [
            # this is optional and filters the data to
            # reduce the size of the dataframe
            ["loan_status", "=", "PAIDOFF"],
        ],
    }
)

llm = OpenAI("OPEN_API_KEY")
df = SmartDataframe(databricks_connector, config={"llm": llm})

response = df.chat("How many people from the United states?")
print(response)
2 changes: 2 additions & 0 deletions pandasai/connectors/__init__.py
@@ -7,6 +7,7 @@
from .base import BaseConnector
from .sql import SQLConnector, MySQLConnector, PostgreSQLConnector
from .snowflake import SnowFlakeConnector
from .databricks import DatabricksConnector
from .yahoo_finance import YahooFinanceConnector

__all__ = [
@@ -16,4 +17,5 @@
"PostgreSQLConnector",
"YahooFinanceConnector",
"SnowFlakeConnector",
"DatabricksConnector",
]
11 changes: 11 additions & 0 deletions pandasai/connectors/base.py
@@ -63,6 +63,17 @@ class SnowFlakeConnectorConfig(SQLBaseConnectorConfig):
warehouse: str


class DatabricksConnectorConfig(SQLBaseConnectorConfig):
    """
    Connector configuration for Databricks.
    """

    host: str
    port: int
    token: str
    httpPath: str


class BaseConnector(ABC):
"""
Base connector class to be extended by all connectors.
69 changes: 69 additions & 0 deletions pandasai/connectors/databricks.py
@@ -0,0 +1,69 @@
"""
Databricks connector that connects you to your Databricks SQL Warehouse on
Azure, AWS and GCP
"""

import os
from .base import BaseConnectorConfig, DatabricksConnectorConfig
from sqlalchemy import create_engine
from typing import Union
from .sql import SQLConnector


class DatabricksConnector(SQLConnector):
    """
    SnowFlake connectors are used to connect to SnowFlake Data Cloud.
    """
Contributor

The class docstring seems to be copied from the SnowFlake connector and not updated. It should reflect that this is a Databricks connector.

-    """
-    SnowFlake connectors are used to connect to SnowFlake Data Cloud.
-    """
+    """
+    Databricks connectors are used to connect to Databricks SQL Warehouse on Azure, AWS, and GCP.
+    """


    def __init__(self, config: DatabricksConnectorConfig):
Collaborator

This also accepts a dictionary, so the annotation should be Union[DatabricksConnectorConfig, dict]. When a dictionary is passed, we need to convert it to a DatabricksConnectorConfig.
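
One way to honor both input types, as a minimal sketch: normalize the argument to a plain dict before any defaults are applied, then build the config object from it. `DatabricksConfigSketch` and `to_config_dict` are hypothetical names, and the dataclass only stands in for the real `DatabricksConnectorConfig`, so this illustrates the idea rather than the PR's code.

```python
from dataclasses import asdict, dataclass
from typing import Union


@dataclass
class DatabricksConfigSketch:
    """Hypothetical stand-in for DatabricksConnectorConfig."""

    host: str
    port: int
    token: str
    httpPath: str
    database: str = "default"
    dialect: str = "databricks"


def to_config_dict(config: Union[DatabricksConfigSketch, dict]) -> dict:
    """Return a fresh dict whether a config object or a dict was passed."""
    if isinstance(config, DatabricksConfigSketch):
        return asdict(config)
    # Copy so that defaults added later do not mutate the caller's dict.
    return dict(config)


raw = {
    "host": "example-host",
    "port": 443,
    "token": "example-token",
    "httpPath": "/sql/1.0/warehouses/example",
}
values = to_config_dict(raw)
values["dialect"] = "databricks"
config = DatabricksConfigSketch(**values)
```

Working on a copy keeps `__init__` free of side effects on the dictionary the caller passed in.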

        """
        Initialize the SnowFlake connector with the given configuration.

        Args:
            config (ConnectorConfig): The configuration for the SnowFlake connector.
        """
Contributor

Again, the docstring for the __init__ method refers to SnowFlake instead of Databricks. This needs to be corrected.

-        Initialize the SnowFlake connector with the given configuration.
+        Initialize the Databricks connector with the given configuration.

-            config (ConnectorConfig): The configuration for the SnowFlake connector.
+            config (DatabricksConnectorConfig): The configuration for the Databricks connector.

        config["dialect"] = "databricks"

        if "token" not in config and os.getenv("DATABRICKS_TOKEN"):
            config["token"] = os.getenv("DATABRICKS_TOKEN")
        if "database" not in config and os.getenv("SNOWFLAKE_DATABASE"):
            config["database"] = os.getenv("SNOWFLAKE_DATABASE")
Contributor

The environment variable being checked here is for Snowflake, not Databricks. This should be corrected to check for a Databricks-specific environment variable.

-        if "database" not in config and os.getenv("SNOWFLAKE_DATABASE"):
-            config["database"] = os.getenv("SNOWFLAKE_DATABASE")
+        if "database" not in config and os.getenv("DATABRICKS_DATABASE"):
+            config["database"] = os.getenv("DATABRICKS_DATABASE")
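
A table-driven fallback makes this kind of copy-paste slip harder to repeat, since every key is paired with its variable in one place. The sketch below is only an illustration: `ENV_FALLBACKS` and `apply_env_fallbacks` are hypothetical names, and `DATABRICKS_DATABASE` is the variable proposed in the suggestion above, not one confirmed by the merged code. Environment values are strings, so the port is converted explicitly.

```python
import os
from typing import Mapping, Optional

# Config key -> environment variable used when the key is missing.
ENV_FALLBACKS = {
    "token": "DATABRICKS_TOKEN",
    "database": "DATABRICKS_DATABASE",
    "host": "DATABRICKS_HOST",
    "port": "DATABRICKS_PORT",
    "httpPath": "DATABRICKS_HTTP_PATH",
}


def apply_env_fallbacks(
    config: dict, environ: Optional[Mapping[str, str]] = None
) -> dict:
    """Return a copy of config with missing keys filled from the environment."""
    env = os.environ if environ is None else environ
    merged = dict(config)
    for key, variable in ENV_FALLBACKS.items():
        value = env.get(variable)
        # Explicit config wins; empty or unset variables are ignored.
        if key not in merged and value:
            merged[key] = value
    if "port" in merged:
        # Environment variables are always strings; the config expects an int.
        merged["port"] = int(merged["port"])
    return merged


example = apply_env_fallbacks(
    {"host": "explicit-host"},
    {"DATABRICKS_HOST": "env-host", "DATABRICKS_PORT": "443", "DATABRICKS_TOKEN": ""},
)
```

Passing the environment as a parameter is a design choice that keeps the function easy to test without touching `os.environ`.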

        if "host" not in config and os.getenv("DATABRICKS_HOST"):
            config["host"] = os.getenv("DATABRICKS_HOST")
        if "port" not in config and os.getenv("DATABRICKS_PORT"):
            config["port"] = os.getenv("DATABRICKS_PORT")
        if "httpPath" not in config and os.getenv("DATABRICKS_HTTP_PATH"):
            config["httpPath"] = os.getenv("DATABRICKS_HTTP_PATH")

        super().__init__(config)

    def _load_connector_config(self, config: Union[BaseConnectorConfig, dict]):
        return DatabricksConnectorConfig(**config)

    def _init_connection(self, config: DatabricksConnectorConfig):
        """
        Initialize Database Connection

        Args:
            config (SQLConnectorConfig): Configurations to load database

        """
        self._engine = create_engine(
            f"{config.dialect}://token:{config.token}@{config.host}:{config.port}?http_path={config.httpPath}"
Contributor

@coderabbitai coderabbitai bot Sep 20, 2023

The connection string is being created without any sanitization or encoding of the input parameters. This could potentially lead to security issues such as SQL injection attacks if the config parameters are user-supplied. Consider using a safer method to construct the connection string, such as URL encoding the parameters or using a library function designed for this purpose.

Collaborator

@coderabbitai, could you make an example?

Contributor

@gventuri, sure! Here's an example of how you could safely construct the connection string using URL encoding:

from urllib.parse import quote_plus

class DatabricksConnector(SQLConnector):
    # ...

    def _init_connection(self, config: DatabricksConnectorConfig):
        """
        Initialize Database Connection

        Args:
            config (DatabricksConnectorConfig): Configurations to load database

        """
        safe_token = quote_plus(config.token)
        safe_host = quote_plus(config.host)
        # port is an int in the config; quote_plus only accepts str or bytes
        safe_port = quote_plus(str(config.port))
        safe_httpPath = quote_plus(config.httpPath)

        self._engine = create_engine(
            f"{config.dialect}://token:{safe_token}@{safe_host}:{safe_port}?http_path={safe_httpPath}"
        )

        self._connection = self._engine.connect()

In this example, I've used Python's built-in urllib.parse.quote_plus function to URL encode the input parameters. This will ensure that any special characters in the parameters are properly escaped, preventing potential SQL injection attacks.
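
A slightly narrower variant, as a sketch under stated assumptions: percent-encode only the parts that can contain reserved characters (the token and the `http_path` query value) and validate the port as an integer instead of quoting it. What the encoding protects here is the URL structure, for example a token containing `@` or `/`. `build_databricks_url` is a hypothetical helper, and whether the Databricks dialect accepts a percent-encoded `http_path` should be checked against its documentation. SQLAlchemy 1.4+ also provides `sqlalchemy.engine.URL.create`, which takes the components separately and may be a better fit.

```python
from urllib.parse import quote, urlencode


def build_databricks_url(token: str, host: str, port: int, http_path: str) -> str:
    """Build a connection URL with the token and http_path percent-encoded."""
    port_number = int(port)
    if not 0 < port_number < 65536:
        raise ValueError(f"Invalid port: {port!r}")
    # safe="" also encodes "/" so the token cannot break the URL structure.
    safe_token = quote(token, safe="")
    # quote_via=quote encodes spaces as %20 rather than "+".
    query = urlencode({"http_path": http_path}, quote_via=quote)
    return f"databricks://token:{safe_token}@{host}:{port_number}?{query}"


url = build_databricks_url(
    "tok/en@x", "adb-123.azuredatabricks.net", 443, "/sql/1.0/warehouses/abc"
)
```

The host is left as given in this sketch; a stricter version could validate it against a hostname pattern.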

        )

        self._connection = self._engine.connect()

    def __repr__(self):
        """
        Return the string representation of the SnowFlake connector.

        Returns:
            str: The string representation of the SnowFlake connector.
        """
Contributor

The docstring for the __repr__ method also refers to SnowFlake instead of Databricks. This needs to be corrected.

-        Return the string representation of the SnowFlake connector.
+        Return the string representation of the Databricks connector.

-        str: The string representation of the SnowFlake connector.
+        str: The string representation of the Databricks connector.

        return (
            f"<{self.__class__.__name__} dialect={self._config.dialect} "
            f"token={self._config.token} "
            f"host={self._config.host} port={self._config.port} "
            f"database={self._config.database} httpPath={str(self._config.httpPath)}"
Contributor

@coderabbitai coderabbitai bot Sep 20, 2023

Exposing sensitive information like tokens in the string representation of the object can be a security risk. If this object is ever logged or printed in a debugging statement, it could expose the token to unauthorized users. Consider removing the token from the string representation.

Collaborator

@ArslanSaleem let's not log the sensitive info!

Contributor

Looks like humans are talking. I'll be quiet for now. If you need me again, just tag @coderabbitai in a new comment.
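
For reference, a minimal sketch of a `__repr__` that leaves the token out entirely. `ConfigSketch` and `ConnectorSketch` are hypothetical stand-ins for the real config and connector classes, so this shows the shape of the fix rather than the code that was merged.

```python
from dataclasses import dataclass


@dataclass
class ConfigSketch:
    """Hypothetical stand-in for the connector configuration."""

    dialect: str
    token: str
    host: str
    port: int
    database: str
    httpPath: str


class ConnectorSketch:
    """Hypothetical stand-in for DatabricksConnector."""

    def __init__(self, config: ConfigSketch):
        self._config = config

    def __repr__(self) -> str:
        config = self._config
        # The token is deliberately omitted so it never reaches logs.
        return (
            f"<{self.__class__.__name__} dialect={config.dialect} "
            f"host={config.host} port={config.port} "
            f"database={config.database} httpPath={config.httpPath}>"
        )


text = repr(
    ConnectorSketch(
        ConfigSketch("databricks", "secret-token", "example-host", 443, "default", "/sql/x")
    )
)
```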

        )
14 changes: 2 additions & 12 deletions pandasai/connectors/sql.py
@@ -279,10 +279,7 @@ def rows_count(self):
)

         # Run a SQL query to get the number of rows
-        query = sql.text(
-            "SELECT COUNT(*) FROM information_schema.columns "
-            "WHERE table_name = :table_name"
-        ).bindparams(table_name=self._config.table)
+        query = sql.text(f"SELECT COUNT(*) FROM {self._config.table}")

         # Return the number of rows
         self._rows_count = self._connection.execute(query).fetchone()[0]
Contributor

@coderabbitai coderabbitai bot Sep 20, 2023

The new hunk simplifies the SQL query for row count, which is a good change. However, it assumes that the table name in the configuration does not contain any special characters or reserved words that might need to be escaped. This could lead to SQL injection vulnerabilities if the table name is user-supplied and not properly sanitized. Consider using bind parameters as in the old hunk to mitigate this risk.

- query = sql.text(f"SELECT COUNT(*) FROM {self._config.table}")
+ query = sql.text("SELECT COUNT(*) FROM :table_name").bindparams(table_name=self._config.table)

Collaborator

Unfortunately this could create a potential attack. We are using bindparams to validate the table name. Also the suggestion from @coderabbitai might work!

Contributor

@gventuri: I'm glad you agree with the suggestion. Using bind parameters is indeed a good practice to prevent SQL injection attacks. It ensures that the table name is properly escaped, preventing malicious input from being executed as part of the SQL query. Here's how you can modify your code:

query = sql.text("SELECT COUNT(*) FROM :table_name").bindparams(table_name=self._config.table)

This way, SQLAlchemy will handle the proper escaping of the table name, making your code safer.

Contributor

@ArslanSaleem: You're correct. My apologies for the confusion. SQLAlchemy's bindparams is indeed used to bind parameters, not table names or column names.

To safely include a table name in a query, we need to use SQLAlchemy's Table construct which provides proper quoting around identifiers. Here's how you can modify your code:

from sqlalchemy import MetaData, Table, func, select

metadata = MetaData()
table = Table(self._config.table, metadata)
# select(func.count()) is the SQLAlchemy 1.4+ calling style
query = select(func.count()).select_from(table)

self._rows_count = self._connection.execute(query).scalar()

This way, SQLAlchemy will handle the proper quoting of the table name, making your code safer.
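
Another option, independent of SQLAlchemy, is to validate the table name against an allow-list pattern before it is interpolated into the query. The sketch below uses the standard-library `sqlite3` module purely as a stand-in database so it can run anywhere; `count_rows` and the pattern are hypothetical. Databricks SQL quotes identifiers with backticks, so the quoting character would differ there.

```python
import re
import sqlite3

# Letters, digits and underscores only; rejects quotes, spaces and semicolons.
_IDENTIFIER = re.compile(r"[A-Za-z_][A-Za-z0-9_]*")


def count_rows(connection: sqlite3.Connection, table: str) -> int:
    """Count rows in a table whose name passes the allow-list check."""
    if not _IDENTIFIER.fullmatch(table):
        raise ValueError(f"Invalid table name: {table!r}")
    # Interpolation is acceptable here because the name matched the pattern.
    query = f'SELECT COUNT(*) FROM "{table}"'
    return connection.execute(query).fetchone()[0]


connection = sqlite3.connect(":memory:")
connection.execute("CREATE TABLE loan_payments_data (loan_status TEXT)")
connection.executemany(
    "INSERT INTO loan_payments_data VALUES (?)",
    [("PAIDOFF",), ("COLLECTION",)],
)
rows = count_rows(connection, "loan_payments_data")
```

A strict pattern like this would reject schema-qualified names such as `schema.table`; supporting those would mean validating each dotted part separately.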

@@ -307,14 +304,7 @@ def columns_count(self):
f"{self._config.dialect}"
)

-        # Run a SQL query to get the number of columns
-        query = sql.text(
-            "SELECT COUNT(*) FROM information_schema.columns "
-            f"WHERE table_name = '{self._config.table}'"
-        )
-
-        # Return the number of columns
-        self._columns_count = self._connection.execute(query).fetchone()[0]
+        self._columns_count = len(self.head().columns)
         return self._columns_count
Comment on lines 304 to 308
Contributor

The columns_count method now uses the len function on the columns returned by the head method instead of running a SQL query. This change could improve performance by reducing the number of database queries. However, make sure that the head method always returns the correct and complete set of columns.
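
One way to make the column count independent of how many rows come back is to read it from the result metadata of a query that returns no rows. This is a sketch with the standard-library `sqlite3` module as a stand-in, using a fixed, trusted table name; whether the Databricks driver exposes the same metadata for an empty result is an assumption to verify.

```python
import sqlite3

connection = sqlite3.connect(":memory:")
connection.execute(
    "CREATE TABLE loan_payments_data (loan_id TEXT, loan_status TEXT, principal INTEGER)"
)

# LIMIT 0 fetches no rows, but the cursor still describes every column.
cursor = connection.execute("SELECT * FROM loan_payments_data LIMIT 0")
columns_count = len(cursor.description)
```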


def _get_column_hash(self, include_additional_filters: bool = False):