Skip to content

Commit

Permalink
feat[DatabricksConnector]: connector to connect to Databricks on Cloud (
Browse files Browse the repository at this point in the history
  • Loading branch information
ArslanSaleem authored and gventuri committed Sep 24, 2023
1 parent 36c1b2d commit a8e5afb
Show file tree
Hide file tree
Showing 11 changed files with 550 additions and 237 deletions.
28 changes: 28 additions & 0 deletions examples/from_databricks.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
"""Example of using PandasAI with a DataBricks"""

from pandasai import SmartDataframe
from pandasai.llm import OpenAI
from pandasai.connectors import DatabricksConnector


# Connection settings for the Databricks SQL warehouse. The host, token and
# warehouse path shown here are masked sample values.
connection_settings = {
    "host": "adb-*****.azuredatabricks.net",
    "database": "default",
    "token": "dapidfd412321",
    "port": 443,
    "table": "loan_payments_data",
    "httpPath": "/sql/1.0/warehouses/213421312",
    # Optional: filter rows at the source to reduce the size of the dataframe.
    "where": [["loan_status", "=", "PAIDOFF"]],
}
databricks_connector = DatabricksConnector(config=connection_settings)

# Wrap the connector in a SmartDataframe backed by the OpenAI LLM.
llm = OpenAI("OPEN_API_KEY")
smart_config = {"llm": llm}
df = SmartDataframe(databricks_connector, config=smart_config)

# Ask a natural-language question and show the answer.
response = df.chat("How many people from the United states?")
print(response)
2 changes: 2 additions & 0 deletions pandasai/connectors/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from .base import BaseConnector
from .sql import SQLConnector, MySQLConnector, PostgreSQLConnector
from .snowflake import SnowFlakeConnector
from .databricks import DatabricksConnector
from .yahoo_finance import YahooFinanceConnector

__all__ = [
Expand All @@ -16,4 +17,5 @@
"PostgreSQLConnector",
"YahooFinanceConnector",
"SnowFlakeConnector",
"DatabricksConnector",
]
30 changes: 30 additions & 0 deletions pandasai/connectors/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
"""

from abc import ABC, abstractmethod
import os
from ..helpers.df_info import DataFrameType
from ..helpers.logger import Logger
from pydantic import BaseModel
Expand Down Expand Up @@ -63,6 +64,17 @@ class SnowFlakeConnectorConfig(SQLBaseConnectorConfig):
warehouse: str


class DatabricksConnectorConfig(SQLBaseConnectorConfig):
    """
    Connector configuration for Databricks.

    Extends ``SQLBaseConnectorConfig`` with the settings that
    ``DatabricksConnector`` interpolates into its SQLAlchemy connection URL.
    When a plain dict is passed to the connector, missing values may be
    supplied through the ``DATABRICKS_*`` environment variables.
    """

    # Host name placed in the authority part of the connection URL.
    host: str
    # Port placed after the host in the connection URL.
    port: int
    # Access token; sent as the password for the fixed user name "token".
    token: str
    # Value passed as the ``http_path`` query parameter of the connection URL.
    httpPath: str


class BaseConnector(ABC):
"""
Base connector class to be extended by all connectors.
Expand Down Expand Up @@ -95,6 +107,24 @@ def _load_connector_config(self, config: Union[BaseConnectorConfig, dict]):
"""
pass

def _populate_config_from_env(self, config: dict, envs_mapping: dict):
    """
    Fill in missing configuration entries from environment variables.

    Args:
        config (dict): The configuration dictionary; it is updated in place.
        envs_mapping (dict): Maps each configuration key to the name of the
            environment variable that may supply its value.

    Returns:
        dict: The same ``config`` object after population.
    """
    for config_key, env_name in envs_mapping.items():
        if config_key in config:
            # An entry that is already present always wins over the environment.
            continue

        env_value = os.getenv(env_name)
        if env_value:
            # Unset and empty variables are both ignored.
            config[config_key] = env_value

    return config

def _init_connection(self, config: BaseConnectorConfig):
"""
make connection to database
Expand Down
65 changes: 65 additions & 0 deletions pandasai/connectors/databricks.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
"""
Databricks Connector connects you to your Databricks SQL Warehouse on
Azure, AWS and GCP
"""

from .base import BaseConnectorConfig, DatabricksConnectorConfig
from sqlalchemy import create_engine
from typing import Union
from .sql import SQLConnector


class DatabricksConnector(SQLConnector):
    """
    Databricks connectors are used to connect to Databricks Data Cloud.
    """

    def __init__(self, config: Union[DatabricksConnectorConfig, dict]):
        """
        Initialize the Databricks connector with the given configuration.

        Args:
            config (Union[DatabricksConnectorConfig, dict]): The configuration
                for the Databricks connector. Keys missing from it are looked
                up in the ``DATABRICKS_*`` environment variables.
        """
        # Normalise to a fresh plain dict. ``dict()`` accepts both a mapping
        # and a pydantic model (which iterates as ``(field, value)`` pairs),
        # so a ``DatabricksConnectorConfig`` instance no longer fails on the
        # item assignment below and the caller's own dict is left untouched.
        config = dict(config)
        config["dialect"] = "databricks"

        env_vars = {
            "token": "DATABRICKS_TOKEN",
            "database": "DATABRICKS_DATABASE",
            "host": "DATABRICKS_HOST",
            "port": "DATABRICKS_PORT",
            "httpPath": "DATABRICKS_HTTP_PATH",
        }
        # Only keys absent from ``config`` are taken from the environment.
        config = self._populate_config_from_env(config, env_vars)

        super().__init__(config)

    def _load_connector_config(self, config: Union[BaseConnectorConfig, dict]):
        """
        Build the validated Databricks configuration object.

        Args:
            config (Union[BaseConnectorConfig, dict]): Raw configuration; it is
                unpacked as keyword arguments, so it must be a mapping.

        Returns:
            DatabricksConnectorConfig: The parsed configuration.
        """
        return DatabricksConnectorConfig(**config)

    def _init_connection(self, config: DatabricksConnectorConfig):
        """
        Initialize Database Connection

        Creates the SQLAlchemy engine and opens a connection, storing them on
        ``self._engine`` and ``self._connection``.

        Args:
            config (DatabricksConnectorConfig): Configurations to load database
        """
        # The user name is the literal "token"; the access token is the password.
        connection_url = (
            f"{config.dialect}://token:{config.token}@{config.host}:{config.port}"
            f"?http_path={config.httpPath}"
        )
        self._engine = create_engine(connection_url)

        self._connection = self._engine.connect()

    def __repr__(self):
        """
        Return the string representation of the Databricks connector.

        The access token is deliberately left out of the output.

        Returns:
            str: The string representation of the Databricks connector.
        """
        # The closing ">" was missing, unlike the other connectors' reprs.
        return (
            f"<{self.__class__.__name__} dialect={self._config.dialect} "
            f"host={self._config.host} port={self._config.port} "
            f"database={self._config.database} httpPath={self._config.httpPath}>"
        )
28 changes: 12 additions & 16 deletions pandasai/connectors/snowflake.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
SnowFlake connectors are used to connect to SnowFlake Data Cloud.
"""

import os
import pandas as pd
from .base import BaseConnectorConfig, SnowFlakeConnectorConfig
from sqlalchemy import create_engine
Expand All @@ -16,7 +15,7 @@ class SnowFlakeConnector(SQLConnector):
SnowFlake connectors are used to connect to SnowFlake Data Cloud.
"""

def __init__(self, config: SnowFlakeConnectorConfig):
def __init__(self, config: Union[SnowFlakeConnectorConfig, dict]):
"""
Initialize the SnowFlake connector with the given configuration.
Expand All @@ -25,18 +24,16 @@ def __init__(self, config: SnowFlakeConnectorConfig):
"""
config["dialect"] = "snowflake"

if "account" not in config and os.getenv("SNOWFLAKE_HOST"):
config["account"] = os.getenv("SNOWFLAKE_HOST")
if "database" not in config and os.getenv("SNOWFLAKE_DATABASE"):
config["database"] = os.getenv("SNOWFLAKE_DATABASE")
if "warehouse" not in config and os.getenv("SNOWFLAKE_WAREHOUSE"):
config["warehouse"] = os.getenv("SNOWFLAKE_WAREHOUSE")
if "dbSchema" not in config and os.getenv("SNOWFLAKE_SCHEMA"):
config["dbSchema"] = os.getenv("SNOWFLAKE_SCHEMA")
if "username" not in config and os.getenv("SNOWFLAKE_USERNAME"):
config["username"] = os.getenv("SNOWFLAKE_USERNAME")
if "password" not in config and os.getenv("SNOWFLAKE_PASSWORD"):
config["password"] = os.getenv("SNOWFLAKE_PASSWORD")
if isinstance(config, dict):
snowflake_env_vars = {
"account": "SNOWFLAKE_HOST",
"database": "SNOWFLAKE_DATABASE",
"warehouse": "SNOWFLAKE_WAREHOUSE",
"dbSchema": "SNOWFLAKE_SCHEMA",
"username": "SNOWFLAKE_USERNAME",
"password": "SNOWFLAKE_PASSWORD",
}
config = self._populate_config_from_env(config, snowflake_env_vars)

super().__init__(config)

Expand Down Expand Up @@ -88,8 +85,7 @@ def __repr__(self):
"""
return (
f"<{self.__class__.__name__} dialect={self._config.dialect} "
f"username={self._config.username} "
f"password={self._config.password} Account={self._config.account} "
f"Account={self._config.account} "
f"warehouse={self._config.warehouse} "
f"database={self._config.database} schema={str(self._config.dbSchema)} "
f"table={self._config.table}>"
Expand Down
62 changes: 25 additions & 37 deletions pandasai/connectors/sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,8 @@
import pandas as pd
from .base import BaseConnector, SQLConnectorConfig
from .base import BaseConnectorConfig
from sqlalchemy import create_engine, sql, text, select, asc
from sqlalchemy import create_engine, text, select, asc

from functools import cached_property, cache
import hashlib
from ..helpers.path import find_project_root
Expand Down Expand Up @@ -94,8 +95,7 @@ def __repr__(self):
"""
return (
f"<{self.__class__.__name__} dialect={self._config.dialect} "
f"driver={self._config.driver} username={self._config.username} "
f"password={self._config.password} host={self._config.host} "
f"driver={self._config.driver} host={self._config.host} "
f"port={str(self._config.port)} database={self._config.database} "
f"table={self._config.table}>"
)
Expand Down Expand Up @@ -279,10 +279,7 @@ def rows_count(self):
)

# Run a SQL query to get the number of rows
query = sql.text(
"SELECT COUNT(*) FROM information_schema.columns "
"WHERE table_name = :table_name"
).bindparams(table_name=self._config.table)
query = select(text("COUNT(*)")).select_from(text(self._config.table))

# Return the number of rows
self._rows_count = self._connection.execute(query).fetchone()[0]
Expand All @@ -307,14 +304,7 @@ def columns_count(self):
f"{self._config.dialect}"
)

# Run a SQL query to get the number of columns
query = sql.text(
"SELECT COUNT(*) FROM information_schema.columns "
f"WHERE table_name = '{self._config.table}'"
)

# Return the number of columns
self._columns_count = self._connection.execute(query).fetchone()[0]
self._columns_count = len(self.head().columns)
return self._columns_count

def _get_column_hash(self, include_additional_filters: bool = False):
Expand Down Expand Up @@ -368,7 +358,7 @@ class MySQLConnector(SQLConnector):
MySQL connectors are used to connect to MySQL databases.
"""

def __init__(self, config: SQLConnectorConfig):
def __init__(self, config: Union[SQLConnectorConfig, dict]):
"""
Initialize the MySQL connector with the given configuration.
Expand All @@ -378,16 +368,15 @@ def __init__(self, config: SQLConnectorConfig):
config["dialect"] = "mysql"
config["driver"] = "pymysql"

if "host" not in config and os.getenv("MYSQL_HOST"):
config["host"] = os.getenv("MYSQL_HOST")
if "port" not in config and os.getenv("MYSQL_PORT"):
config["port"] = os.getenv("MYSQL_PORT")
if "database" not in config and os.getenv("MYSQL_DATABASE"):
config["database"] = os.getenv("MYSQL_DATABASE")
if "username" not in config and os.getenv("MYSQL_USERNAME"):
config["username"] = os.getenv("MYSQL_USERNAME")
if "password" not in config and os.getenv("MYSQL_PASSWORD"):
config["password"] = os.getenv("MYSQL_PASSWORD")
if isinstance(config, dict):
mysql_env_vars = {
"host": "MYSQL_HOST",
"port": "MYSQL_PORT",
"database": "MYSQL_DATABASE",
"username": "MYSQL_USERNAME",
"password": "MYSQL_PASSWORD",
}
config = self._populate_config_from_env(config, mysql_env_vars)

super().__init__(config)

Expand All @@ -397,7 +386,7 @@ class PostgreSQLConnector(SQLConnector):
PostgreSQL connectors are used to connect to PostgreSQL databases.
"""

def __init__(self, config: SQLConnectorConfig):
def __init__(self, config: Union[SQLConnectorConfig, dict]):
"""
Initialize the PostgreSQL connector with the given configuration.
Expand All @@ -407,16 +396,15 @@ def __init__(self, config: SQLConnectorConfig):
config["dialect"] = "postgresql"
config["driver"] = "psycopg2"

if "host" not in config and os.getenv("POSTGRESQL_HOST"):
config["host"] = os.getenv("POSTGRESQL_HOST")
if "port" not in config and os.getenv("POSTGRESQL_PORT"):
config["port"] = os.getenv("POSTGRESQL_PORT")
if "database" not in config and os.getenv("POSTGRESQL_DATABASE"):
config["database"] = os.getenv("POSTGRESQL_DATABASE")
if "username" not in config and os.getenv("POSTGRESQL_USERNAME"):
config["username"] = os.getenv("POSTGRESQL_USERNAME")
if "password" not in config and os.getenv("POSTGRESQL_PASSWORD"):
config["password"] = os.getenv("POSTGRESQL_PASSWORD")
if isinstance(config, dict):
postgresql_env_vars = {
"host": "POSTGRESQL_HOST",
"port": "POSTGRESQL_PORT",
"database": "POSTGRESQL_DATABASE",
"username": "POSTGRESQL_USERNAME",
"password": "POSTGRESQL_PASSWORD",
}
config = self._populate_config_from_env(config, postgresql_env_vars)

super().__init__(config)

Expand Down
Loading

0 comments on commit a8e5afb

Please sign in to comment.