Skip to content

Commit

Permalink
Fix: Databricks connection issues
Browse files Browse the repository at this point in the history
  • Loading branch information
ArslanSaleem committed Sep 20, 2023
1 parent a7da1a5 commit 70c71ed
Show file tree
Hide file tree
Showing 7 changed files with 257 additions and 191 deletions.
16 changes: 8 additions & 8 deletions examples/from_databricks.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,22 +6,22 @@

databricks_connector = DatabricksConnector(
config={
"host": "ehxzojy-ue47135",
"database": "SNOWFLAKE_SAMPLE_DATA",
"token": "",
"host": "adb-*****.azuredatabricks.net",
"database": "default",
"token": "dapidfd412321",
"port": 443,
"table": "lineitem",
"httpPath": "tpch_sf1",
"table": "loan_payments_data",
"httpPath": "/sql/1.0/warehouses/213421312",
"where": [
# this is optional and filters the data to
# reduce the size of the dataframe
["l_quantity", ">", "49"]
["loan_status", "=", "PAIDOFF"],
],
}
)

llm = OpenAI(api_token="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")
llm = OpenAI("OPEN_API_KEY")
df = SmartDataframe(databricks_connector, config={"llm": llm})

response = df.chat("How many records has status 'F'?")
response = df.chat("How many people from the United states?")
print(response)
14 changes: 2 additions & 12 deletions pandasai/connectors/sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -279,10 +279,7 @@ def rows_count(self):
)

# Run a SQL query to get the number of rows
query = sql.text(
"SELECT COUNT(*) FROM information_schema.columns "
"WHERE table_name = :table_name"
).bindparams(table_name=self._config.table)
query = sql.text(f"SELECT COUNT(*) FROM {self._config.table}")

# Return the number of rows
self._rows_count = self._connection.execute(query).fetchone()[0]
Expand All @@ -307,14 +304,7 @@ def columns_count(self):
f"{self._config.dialect}"
)

# Run a SQL query to get the number of columns
query = sql.text(
"SELECT COUNT(*) FROM information_schema.columns "
f"WHERE table_name = '{self._config.table}'"
)

# Return the number of columns
self._columns_count = self._connection.execute(query).fetchone()[0]
self._columns_count = len(self.head().columns)
return self._columns_count

def _get_column_hash(self, include_additional_filters: bool = False):
Expand Down
392 changes: 234 additions & 158 deletions poetry.lock

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ coverage = "^7.2.7"
google-cloud-aiplatform = "^1.26.1"

[tool.poetry.extras]
connectors = ["pymysql", "psycopg2", "snowflake-sqlalchemy", "databricks-sql-connector"]
connectors = ["pymysql", "psycopg2", "snowflake-sqlalchemy", "sqlalchemy-databricks"]
google-ai = ["google-generativeai", "google-cloud-aiplatform"]
google-sheets = ["beautifulsoup4"]
excel = ["openpyxl"]
Expand Down
8 changes: 4 additions & 4 deletions tests/connectors/test_databricks.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,11 +76,11 @@ def test_rows_count_property(self):
def test_columns_count_property(self):
# Test columns_count property
self.connector._columns_count = None
self.mock_connection.execute.return_value.fetchone.return_value = (
8,
) # Sample columns count
mock_df = Mock()
mock_df.columns = ["Column1", "Column2"]
self.connector.head = Mock(return_value=mock_df)
columns_count = self.connector.columns_count
self.assertEqual(columns_count, 8)
self.assertEqual(columns_count, 2)

def test_column_hash_property(self):
# Test column_hash property
Expand Down
8 changes: 4 additions & 4 deletions tests/connectors/test_snowflake.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,11 +76,11 @@ def test_rows_count_property(self):
def test_columns_count_property(self):
# Test columns_count property
self.connector._columns_count = None
self.mock_connection.execute.return_value.fetchone.return_value = (
8,
) # Sample columns count
mock_df = Mock()
mock_df.columns = ["Column1", "Column2"]
self.connector.head = Mock(return_value=mock_df)
columns_count = self.connector.columns_count
self.assertEqual(columns_count, 8)
self.assertEqual(columns_count, 2)

def test_column_hash_property(self):
# Test column_hash property
Expand Down
8 changes: 4 additions & 4 deletions tests/connectors/test_sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,11 +76,11 @@ def test_rows_count_property(self):
def test_columns_count_property(self):
# Test columns_count property
self.connector._columns_count = None
self.mock_connection.execute.return_value.fetchone.return_value = (
8,
) # Sample columns count
mock_df = Mock()
mock_df.columns = ["Column1", "Column2"]
self.connector.head = Mock(return_value=mock_df)
columns_count = self.connector.columns_count
self.assertEqual(columns_count, 8)
self.assertEqual(columns_count, 2)

def test_column_hash_property(self):
# Test column_hash property
Expand Down

0 comments on commit 70c71ed

Please sign in to comment.