Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat[DatabricksConnector]: connector to connect to Databricks on Cloud #580

Merged
merged 5 commits into from
Sep 21, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 28 additions & 0 deletions examples/from_databricks.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
"""Example of using PandasAI with a DataBricks"""

gventuri marked this conversation as resolved.
Show resolved Hide resolved
from pandasai import SmartDataframe
from pandasai.llm import OpenAI
from pandasai.connectors import DatabricksConnector


databricks_connector = DatabricksConnector(
config={
"host": "adb-*****.azuredatabricks.net",
"database": "default",
"token": "dapidfd412321",
"port": 443,
"table": "loan_payments_data",
"httpPath": "/sql/1.0/warehouses/213421312",
"where": [
# this is optional and filters the data to
# reduce the size of the dataframe
["loan_status", "=", "PAIDOFF"],
],
}
gventuri marked this conversation as resolved.
Show resolved Hide resolved
)

llm = OpenAI("OPEN_API_KEY")
gventuri marked this conversation as resolved.
Show resolved Hide resolved
df = SmartDataframe(databricks_connector, config={"llm": llm})

response = df.chat("How many people from the United states?")
print(response)
2 changes: 2 additions & 0 deletions pandasai/connectors/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from .base import BaseConnector
from .sql import SQLConnector, MySQLConnector, PostgreSQLConnector
from .snowflake import SnowFlakeConnector
from .databricks import DatabricksConnector
from .yahoo_finance import YahooFinanceConnector

__all__ = [
Expand All @@ -16,4 +17,5 @@
"PostgreSQLConnector",
"YahooFinanceConnector",
"SnowFlakeConnector",
"DatabricksConnector",
]
30 changes: 30 additions & 0 deletions pandasai/connectors/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
"""

from abc import ABC, abstractmethod
import os
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The os module is imported but not used in the provided hunk. If it's not used elsewhere in the file, consider removing this import to keep the code clean and efficient.

- import os

from ..helpers.df_info import DataFrameType
from ..helpers.logger import Logger
from pydantic import BaseModel
Expand Down Expand Up @@ -63,6 +64,17 @@ class SnowFlakeConnectorConfig(SQLBaseConnectorConfig):
warehouse: str


class DatabricksConnectorConfig(SQLBaseConnectorConfig):
"""
Connector configuration for DataBricks.
"""

host: str
port: int
token: str
httpPath: str


class BaseConnector(ABC):
"""
Base connector class to be extended by all connectors.
Expand Down Expand Up @@ -95,6 +107,24 @@ def _load_connector_config(self, config: Union[BaseConnectorConfig, dict]):
"""
pass

def _populate_config_from_env(self, config: dict, envs_mapping: dict):
"""
Populate the configuration dictionary with values from environment variables
if not exists in the config.

Args:
config (dict): The configuration dictionary to be populated.

Returns:
dict: The populated configuration dictionary.
"""

for key, env_var in envs_mapping.items():
if key not in config and os.getenv(env_var):
config[key] = os.getenv(env_var)

return config

def _init_connection(self, config: BaseConnectorConfig):
"""
make connection to database
Expand Down
65 changes: 65 additions & 0 deletions pandasai/connectors/databricks.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
"""
Databricks Connector to connects you to your Databricks SQL Warhouse on
Azure, AWS and GCP
"""

from .base import BaseConnectorConfig, DatabricksConnectorConfig
from sqlalchemy import create_engine
from typing import Union
from .sql import SQLConnector


class DatabricksConnector(SQLConnector):
"""
Databricks connectors are used to connect to Databricks Data Cloud.
"""

def __init__(self, config: Union[DatabricksConnectorConfig, dict]):
"""
Initialize the Databricks connector with the given configuration.

Args:
config (ConnectorConfig): The configuration for the Databricks connector.
"""
config["dialect"] = "databricks"
if isinstance(config, dict):
env_vars = {
"token": "DATABRICKS_TOKEN",
"database": "DATABRICKS_DATABASE",
"host": "DATABRICKS_HOST",
"port": "DATABRICKS_PORT",
"httpPath": "DATABRICKS_HTTP_PATH",
}
config = self._populate_config_from_env(config, env_vars)

gventuri marked this conversation as resolved.
Show resolved Hide resolved
super().__init__(config)

def _load_connector_config(self, config: Union[BaseConnectorConfig, dict]):
return DatabricksConnectorConfig(**config)

def _init_connection(self, config: DatabricksConnectorConfig):
"""
Initialize Database Connection

Args:
config (DatabricksConnectorConfig): Configurations to load database

"""
self._engine = create_engine(
f"{config.dialect}://token:{config.token}@{config.host}:{config.port}?http_path={config.httpPath}"
Copy link
Contributor

@coderabbitai coderabbitai bot Sep 20, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The connection string is being created without any sanitization or encoding of the input parameters. This could potentially lead to security issues such as SQL injection attacks if the config parameters are user-supplied. Consider using a safer method to construct the connection string, such as URL encoding the parameters or using a library function designed for this purpose.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@coderabbitai, could you make an example?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@gventuri, sure! Here's an example of how you could safely construct the connection string using URL encoding:

from urllib.parse import quote_plus

class DatabricksConnector(SQLConnector):
    # ...

    def _init_connection(self, config: DatabricksConnectorConfig):
        """
        Initialize Database Connection

        Args:
            config (DatabricksConnectorConfig): Configurations to load database

        """
        safe_token = quote_plus(config.token)
        safe_host = quote_plus(config.host)
        safe_port = quote_plus(config.port)
        safe_httpPath = quote_plus(config.httpPath)

        self._engine = create_engine(
            f"{config.dialect}://token:{safe_token}@{safe_host}:{safe_port}?http_path={safe_httpPath}"
        )

        self._connection = self._engine.connect()

In this example, I've used Python's built-in urllib.parse.quote_plus function to URL encode the input parameters. This will ensure that any special characters in the parameters are properly escaped, preventing potential SQL injection attacks.

)

self._connection = self._engine.connect()

def __repr__(self):
"""
Return the string representation of the Databricks connector.

Returns:
str: The string representation of the Databricks connector.
"""
return (
f"<{self.__class__.__name__} dialect={self._config.dialect} "
f"host={self._config.host} port={self._config.port} "
f"database={self._config.database} httpPath={str(self._config.httpPath)}"
)
Comment on lines +61 to +65
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The __repr__ method is revealing potentially sensitive information such as the host and port. It's generally not a good practice to expose such details. Consider removing these from the representation.

-        return (
-            f"<{self.__class__.__name__} dialect={self._config.dialect} "
-            f"host={self._config.host} port={self._config.port} "
-            f"database={self._config.database} httpPath={str(self._config.httpPath)}"
-        )
+        return f"<{self.__class__.__name__} database={self._config.database}>"

28 changes: 12 additions & 16 deletions pandasai/connectors/snowflake.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
SnowFlake connectors are used to connect to SnowFlake Data Cloud.
"""

import os
import pandas as pd
from .base import BaseConnectorConfig, SnowFlakeConnectorConfig
from sqlalchemy import create_engine
Expand All @@ -16,7 +15,7 @@ class SnowFlakeConnector(SQLConnector):
SnowFlake connectors are used to connect to SnowFlake Data Cloud.
"""

def __init__(self, config: SnowFlakeConnectorConfig):
def __init__(self, config: Union[SnowFlakeConnectorConfig, dict]):
"""
Initialize the SnowFlake connector with the given configuration.

Expand All @@ -25,18 +24,16 @@ def __init__(self, config: SnowFlakeConnectorConfig):
"""
config["dialect"] = "snowflake"

if "account" not in config and os.getenv("SNOWFLAKE_HOST"):
config["account"] = os.getenv("SNOWFLAKE_HOST")
if "database" not in config and os.getenv("SNOWFLAKE_DATABASE"):
config["database"] = os.getenv("SNOWFLAKE_DATABASE")
if "warehouse" not in config and os.getenv("SNOWFLAKE_WAREHOUSE"):
config["warehouse"] = os.getenv("SNOWFLAKE_WAREHOUSE")
if "dbSchema" not in config and os.getenv("SNOWFLAKE_SCHEMA"):
config["dbSchema"] = os.getenv("SNOWFLAKE_SCHEMA")
if "username" not in config and os.getenv("SNOWFLAKE_USERNAME"):
config["username"] = os.getenv("SNOWFLAKE_USERNAME")
if "password" not in config and os.getenv("SNOWFLAKE_PASSWORD"):
config["password"] = os.getenv("SNOWFLAKE_PASSWORD")
if isinstance(config, dict):
snowflake_env_vars = {
"account": "SNOWFLAKE_HOST",
"database": "SNOWFLAKE_DATABASE",
"warehouse": "SNOWFLAKE_WAREHOUSE",
"dbSchema": "SNOWFLAKE_SCHEMA",
"username": "SNOWFLAKE_USERNAME",
"password": "SNOWFLAKE_PASSWORD",
}
config = self._populate_config_from_env(config, snowflake_env_vars)

super().__init__(config)

Expand Down Expand Up @@ -88,8 +85,7 @@ def __repr__(self):
"""
return (
f"<{self.__class__.__name__} dialect={self._config.dialect} "
f"username={self._config.username} "
f"password={self._config.password} Account={self._config.account} "
f"Account={self._config.account} "
f"warehouse={self._config.warehouse} "
f"database={self._config.database} schema={str(self._config.dbSchema)} "
f"table={self._config.table}>"
Expand Down
62 changes: 25 additions & 37 deletions pandasai/connectors/sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,8 @@
import pandas as pd
from .base import BaseConnector, SQLConnectorConfig
from .base import BaseConnectorConfig
from sqlalchemy import create_engine, sql, text, select, asc
from sqlalchemy import create_engine, text, select, asc

from functools import cached_property, cache
import hashlib
from ..helpers.path import find_project_root
Expand Down Expand Up @@ -94,8 +95,7 @@ def __repr__(self):
"""
return (
f"<{self.__class__.__name__} dialect={self._config.dialect} "
f"driver={self._config.driver} username={self._config.username} "
f"password={self._config.password} host={self._config.host} "
f"driver={self._config.driver} host={self._config.host} "
f"port={str(self._config.port)} database={self._config.database} "
f"table={self._config.table}>"
)
Expand Down Expand Up @@ -279,10 +279,7 @@ def rows_count(self):
)

# Run a SQL query to get the number of rows
query = sql.text(
"SELECT COUNT(*) FROM information_schema.columns "
"WHERE table_name = :table_name"
).bindparams(table_name=self._config.table)
query = select(text("COUNT(*)")).select_from(text(self._config.table))

# Return the number of rows
self._rows_count = self._connection.execute(query).fetchone()[0]
Comment on lines 279 to 285
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The SQL query to get the number of rows has been simplified. The new query directly counts the rows in the table, which should be more efficient than the previous approach of counting columns in the information_schema.columns table. However, ensure that the user executing this query has sufficient permissions on the target table.

Expand All @@ -307,14 +304,7 @@ def columns_count(self):
f"{self._config.dialect}"
)

# Run a SQL query to get the number of columns
query = sql.text(
"SELECT COUNT(*) FROM information_schema.columns "
f"WHERE table_name = '{self._config.table}'"
)

# Return the number of columns
self._columns_count = self._connection.execute(query).fetchone()[0]
self._columns_count = len(self.head().columns)
return self._columns_count
gventuri marked this conversation as resolved.
Show resolved Hide resolved
Comment on lines 304 to 308
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The columns_count method now uses the len function on the columns returned by the head method instead of running a SQL query. This change could improve performance by reducing the number of database queries. However, make sure that the head method always returns the correct and complete set of columns.


def _get_column_hash(self, include_additional_filters: bool = False):
Expand Down Expand Up @@ -368,7 +358,7 @@ class MySQLConnector(SQLConnector):
MySQL connectors are used to connect to MySQL databases.
"""

def __init__(self, config: SQLConnectorConfig):
def __init__(self, config: Union[SQLConnectorConfig, dict]):
"""
Initialize the MySQL connector with the given configuration.

Expand All @@ -378,16 +368,15 @@ def __init__(self, config: SQLConnectorConfig):
config["dialect"] = "mysql"
config["driver"] = "pymysql"

if "host" not in config and os.getenv("MYSQL_HOST"):
config["host"] = os.getenv("MYSQL_HOST")
if "port" not in config and os.getenv("MYSQL_PORT"):
config["port"] = os.getenv("MYSQL_PORT")
if "database" not in config and os.getenv("MYSQL_DATABASE"):
config["database"] = os.getenv("MYSQL_DATABASE")
if "username" not in config and os.getenv("MYSQL_USERNAME"):
config["username"] = os.getenv("MYSQL_USERNAME")
if "password" not in config and os.getenv("MYSQL_PASSWORD"):
config["password"] = os.getenv("MYSQL_PASSWORD")
if isinstance(config, dict):
mysql_env_vars = {
"host": "MYSQL_HOST",
"port": "MYSQL_PORT",
"database": "MYSQL_DATABASE",
"username": "MYSQL_USERNAME",
"password": "MYSQL_PASSWORD",
}
config = self._populate_config_from_env(config, mysql_env_vars)

super().__init__(config)

Expand All @@ -397,7 +386,7 @@ class PostgreSQLConnector(SQLConnector):
PostgreSQL connectors are used to connect to PostgreSQL databases.
"""

def __init__(self, config: SQLConnectorConfig):
def __init__(self, config: Union[SQLConnectorConfig, dict]):
"""
Initialize the PostgreSQL connector with the given configuration.

Expand All @@ -407,16 +396,15 @@ def __init__(self, config: SQLConnectorConfig):
config["dialect"] = "postgresql"
config["driver"] = "psycopg2"

if "host" not in config and os.getenv("POSTGRESQL_HOST"):
config["host"] = os.getenv("POSTGRESQL_HOST")
if "port" not in config and os.getenv("POSTGRESQL_PORT"):
config["port"] = os.getenv("POSTGRESQL_PORT")
if "database" not in config and os.getenv("POSTGRESQL_DATABASE"):
config["database"] = os.getenv("POSTGRESQL_DATABASE")
if "username" not in config and os.getenv("POSTGRESQL_USERNAME"):
config["username"] = os.getenv("POSTGRESQL_USERNAME")
if "password" not in config and os.getenv("POSTGRESQL_PASSWORD"):
config["password"] = os.getenv("POSTGRESQL_PASSWORD")
if isinstance(config, dict):
postgresql_env_vars = {
"host": "POSTGRESQL_HOST",
"port": "POSTGRESQL_PORT",
"database": "POSTGRESQL_DATABASE",
"username": "POSTGRESQL_USERNAME",
"password": "POSTGRESQL_PASSWORD",
}
config = self._populate_config_from_env(config, postgresql_env_vars)

super().__init__(config)

Expand Down
Loading