From 66ce7238c5eed6cb04b0b9d3ca1d3ea5a462c01e Mon Sep 17 00:00:00 2001 From: Prithvi Kannan Date: Wed, 16 Oct 2024 22:59:52 -0700 Subject: [PATCH 01/10] Create Genie API wrapper Signed-off-by: Prithvi Kannan --- src/databricks_ai_bridge/genie.py | 116 ++++++++++++++++++++++++++++++ 1 file changed, 116 insertions(+) create mode 100644 src/databricks_ai_bridge/genie.py diff --git a/src/databricks_ai_bridge/genie.py b/src/databricks_ai_bridge/genie.py new file mode 100644 index 0000000..fcbd955 --- /dev/null +++ b/src/databricks_ai_bridge/genie.py @@ -0,0 +1,116 @@ +import time +from datetime import datetime +from typing import Union + +import pandas as pd +from databricks.sdk import WorkspaceClient + + +def _parse_query_result(resp) -> Union[str, pd.DataFrame]: + columns = resp["manifest"]["schema"]["columns"] + header = [str(col["name"]) for col in columns] + rows = [] + output = resp["result"] + if not output: + return "EMPTY" + for item in resp["result"]["data_typed_array"]: + row = [] + for column, value in zip(columns, item["values"]): + type_name = column["type_name"] + str_value = value.get("str", None) + if str_value is None: + row.append(None) + continue + match type_name: + case "INT" | "LONG" | "SHORT" | "BYTE": + row.append(int(str_value)) + case "FLOAT" | "DOUBLE" | "DECIMAL": + row.append(float(str_value)) + case "BOOLEAN": + row.append(str_value.lower() == "true") + case "DATE": + row.append(datetime.strptime(str_value[:10], "%Y-%m-%d").date()) + case "TIMESTAMP": + row.append(datetime.strptime(str_value[:10], "%Y-%m-%d").date()) + case "BINARY": + row.append(bytes(str_value, "utf-8")) + case _: + row.append(str_value) + rows.append(row) + query_result = pd.DataFrame(rows, columns=header).to_string() + return query_result + + +class Genie: + def __init__(self, space_id): + self.space_id = space_id + workspace_client = WorkspaceClient() + self.genie = workspace_client.genie + self.headers = { + "Accept": "application/json", + "Content-Type": "application/json", + } + + def start_conversation(self, content): + resp = self.genie._api.do( + "POST", + f"/api/2.0/genie/spaces/{self.space_id}/start-conversation", + body={"content": content}, + headers=self.headers, + ) + return resp + + def create_message(self, conversation_id, content): + resp = self.genie._api.do( + "POST", + f"/api/2.0/genie/spaces/{self.space_id}/conversations/{conversation_id}/messages", + body={"content": content}, + headers=self.headers, + ) + return resp + + def poll_for_result(self, conversation_id, message_id): + def poll_result(): + while True: + resp = self.genie._api.do( + "GET", + f"/api/2.0/genie/spaces/{self.space_id}/conversations/{conversation_id}/messages/{message_id}", + headers=self.headers, + ) + if resp["status"] == "EXECUTING_QUERY": + sql = next(r for r in resp["attachments"] if "query" in r)["query"][ + "query" + ] + # print(f"SQL: {sql}") + return poll_query_results() + elif resp["status"] == "COMPLETED": + return next(r for r in resp["attachments"] if "text" in r)["text"][ + "content" + ] + else: + # print(f"Waiting...: {resp['status']}") + time.sleep(5) + + def poll_query_results(): + while True: + resp = self.genie._api.do( + "GET", + f"/api/2.0/genie/spaces/{self.space_id}/conversations/{conversation_id}/messages/{message_id}/query-result", + headers=self.headers, + )["statement_response"] + state = resp["status"]["state"] + if state == "SUCCEEDED": + return _parse_query_result(resp) + elif state == "RUNNING" or state == "PENDING": + # print(f"Waiting for query result...") + time.sleep(5) + else: + # print(f"No query result: {resp['state']}") + return None + + return poll_result() + + def ask_question(self, question): + resp = self.start_conversation(question) + # TODO (prithvi): return the query and the result + return self.poll_for_result(resp["conversation_id"], resp["message_id"]) \ No newline at end of file From ac93eae669da72d13a72eb69a238cb42443aa680 Mon Sep 17 00:00:00 2001 From: Prithvi Kannan Date: Wed, 16 Oct 2024 23:00:11 -0700 Subject: [PATCH 02/10] package setup Signed-off-by: Prithvi Kannan --- CONTRIBUTING.md | 9 +++++++++ requirements/dev-requirements.txt | 1 + 2 files changed, 10 insertions(+) create mode 100644 CONTRIBUTING.md create mode 100644 requirements/dev-requirements.txt diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md new file mode 100644 index 0000000..1673cd6 --- /dev/null +++ b/CONTRIBUTING.md @@ -0,0 +1,9 @@ +Setting up dev environment + +Create a conda environement and install dev requirements + +``` +conda create --name databricks-ai-dev-env python=3.10 +conda activate databricks-ai-dev-env +pip install -r requirements/dev-requirements.txt +``` diff --git a/requirements/dev-requirements.txt b/requirements/dev-requirements.txt new file mode 100644 index 0000000..4812a54 --- /dev/null +++ b/requirements/dev-requirements.txt @@ -0,0 +1 @@ +databricks-sdk>=0.34.0 From 188ac831a57321d03a554ef61e51026fb37c1c0b Mon Sep 17 00:00:00 2001 From: Prithvi Kannan Date: Wed, 16 Oct 2024 23:14:11 -0700 Subject: [PATCH 03/10] Add genie tests and resources to run Signed-off-by: Prithvi Kannan --- CONTRIBUTING.md | 3 +- pyproject.toml | 83 ++++++++++++++++++++++++ requirements/dev-requirements.txt | 1 - src/__init__.py | 0 src/databricks_ai_bridge/version.py | 1 + tests/databricks_ai_bridge/test_genie.py | 80 +++++++++++++++++++++++ 6 files changed, 166 insertions(+), 2 deletions(-) create mode 100644 pyproject.toml delete mode 100644 requirements/dev-requirements.txt delete mode 100644 src/__init__.py create mode 100644 src/databricks_ai_bridge/version.py create mode 100644 tests/databricks_ai_bridge/test_genie.py diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 1673cd6..7d00e19 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -5,5 +5,6 @@ Create a conda environement and install dev requirements ``` conda create --name databricks-ai-dev-env python=3.10 conda activate databricks-ai-dev-env -pip install -r requirements/dev-requirements.txt +pip install -e ".[databricks-dev]" +pip install -r requirements/lint-requirements.txt ``` diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 0000000..6e76c92 --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,83 @@ +[project] +name = "databricks-ai-bridge" +version = "0.0.1" +description = "Official Python library for Databricks AI support" +authors = [ + { name="Prithvi Kannan", email="prithvi.kannan@databricks.com" }, +] +readme = "README.md" +requires-python = ">=3.9" +dependencies = [ + "typing_extensions", + "pydantic" +] + +[project.license] +file = "LICENSE.txt" + +[build-system] +requires = ["hatchling"] +build-backend = "hatchling.build" + +[tool.hatch.build] +include = [ + "src/*" +] + +[tool.hatch.build.targets.wheel] +packages = ["src/databricks_ai_bridge"] + +[project.optional-dependencies] +databricks = [ + "databricks-sdk>=0.34.0", + "pandas", +] +databricks-dev = [ + "hatch", + "pytest", + "databricks-sdk>=0.34.0", + "pandas", + "ruff==0.6.4", +] +dev = [ + "hatch", + "pytest", + "databricks-sdk>=0.34.0", + "pandas", + "ruff==0.6.4", +] + +[tool.ruff] +line-length = 100 +target-version = "py39" + +[tool.ruff.lint] +select = [ + # isort + "I", + # bugbear rules + "B", + # remove unused imports + "F401", + # bare except statements + "E722", + # print statements + "T201", + "T203", + # misuse of typing.TYPE_CHECKING + "TCH004", + # import rules + "TID251", + # undefined-local-with-import-star + "F403", +] + +[tool.ruff.format] +docstring-code-format = true +docstring-code-line-length = 88 + +[tool.ruff.lint.pydocstyle] +convention = "google" + +[tool.pytest.ini_options] +pythonpath = ["src"] diff --git a/requirements/dev-requirements.txt b/requirements/dev-requirements.txt deleted file mode 100644 index 4812a54..0000000 --- a/requirements/dev-requirements.txt +++ /dev/null @@ -1 +0,0 @@ -databricks-sdk>=0.34.0 diff --git a/src/__init__.py b/src/__init__.py deleted file mode 100644 index e69de29..0000000 diff --git a/src/databricks_ai_bridge/version.py b/src/databricks_ai_bridge/version.py new file mode 100644 index 0000000..901e511 --- /dev/null +++ b/src/databricks_ai_bridge/version.py @@ -0,0 +1 @@ +VERSION = "0.0.1" diff --git a/tests/databricks_ai_bridge/test_genie.py b/tests/databricks_ai_bridge/test_genie.py new file mode 100644 index 0000000..05a0dd7 --- /dev/null +++ b/tests/databricks_ai_bridge/test_genie.py @@ -0,0 +1,80 @@ +import pytest +from unittest.mock import MagicMock, patch +from databricks_ai_bridge.genie import Genie, _parse_query_result +import pandas as pd + +@pytest.fixture +def mock_workspace_client(): + with patch("databricks_ai_bridge.genie.WorkspaceClient") as MockWorkspaceClient: + mock_client = MockWorkspaceClient.return_value + yield mock_client + +@pytest.fixture +def genie(mock_workspace_client): + return Genie(space_id="test_space_id") + +def test_start_conversation(genie, mock_workspace_client): + mock_workspace_client.genie._api.do.return_value = {"conversation_id": "123"} + response = genie.start_conversation("Hello") + assert response == {"conversation_id": "123"} + mock_workspace_client.genie._api.do.assert_called_once_with( + "POST", + "/api/2.0/genie/spaces/test_space_id/start-conversation", + body={"content": "Hello"}, + headers=genie.headers, + ) + +def test_create_message(genie, mock_workspace_client): + mock_workspace_client.genie._api.do.return_value = {"message_id": "456"} + response = genie.create_message("123", "Hello again") + assert response == {"message_id": "456"} + mock_workspace_client.genie._api.do.assert_called_once_with( + "POST", + "/api/2.0/genie/spaces/test_space_id/conversations/123/messages", + body={"content": "Hello again"}, + headers=genie.headers, + ) + +def test_poll_for_result_completed(genie, mock_workspace_client): + mock_workspace_client.genie._api.do.side_effect = [ + {"status": "COMPLETED", "attachments": [{"text": {"content": "Result"}}]}, + ] + result = genie.poll_for_result("123", "456") + assert result == "Result" + +def test_poll_for_result_executing_query(genie, mock_workspace_client): + mock_workspace_client.genie._api.do.side_effect = [ + {"status": "EXECUTING_QUERY", "attachments": [{"query": {"query": "SELECT *"}}]}, + {"statement_response": {"status": {"state": "SUCCEEDED"}, "result": {"data_typed_array": [], "manifest": {"schema": {"columns": []}}}}}, + ] + result = genie.poll_for_result("123", "456") + assert result == "EMPTY" + +def test_ask_question(genie, mock_workspace_client): + mock_workspace_client.genie._api.do.side_effect = [ + {"conversation_id": "123", "message_id": "456"}, + {"status": "COMPLETED", "attachments": [{"text": {"content": "Answer"}}]}, + ] + result = genie.ask_question("What is the meaning of life?") + assert result == "Answer" + +def test_parse_query_result(): + resp = { + "manifest": { + "schema": { + "columns": [ + {"name": "col1", "type_name": "STRING"}, + {"name": "col2", "type_name": "INT"}, + ] + } + }, + "result": { + "data_typed_array": [ + {"values": [{"str": "value1"}, {"str": "1"}]}, + {"values": [{"str": "value2"}, {"str": "2"}]}, + ] + } + } + expected_df = pd.DataFrame({"col1": ["value1", "value2"], "col2": [1, 2]}) + result = _parse_query_result(resp) + assert result == expected_df.to_string() \ No newline at end of file From ded9ec25a38d6a4cd19b9baff3ef9d2c6fb3140e Mon Sep 17 00:00:00 2001 From: Prithvi Kannan Date: Wed, 16 Oct 2024 23:18:26 -0700 Subject: [PATCH 04/10] update tests Signed-off-by: Prithvi Kannan --- tests/databricks_ai_bridge/test_genie.py | 62 +++++++++++++++++++++--- 1 file changed, 55 insertions(+), 7 deletions(-) diff --git a/tests/databricks_ai_bridge/test_genie.py b/tests/databricks_ai_bridge/test_genie.py index 05a0dd7..ac65d17 100644 --- a/tests/databricks_ai_bridge/test_genie.py +++ b/tests/databricks_ai_bridge/test_genie.py @@ -1,3 +1,4 @@ +from datetime import datetime import pytest from unittest.mock import MagicMock, patch from databricks_ai_bridge.genie import Genie, _parse_query_result @@ -57,24 +58,71 @@ def test_ask_question(genie, mock_workspace_client): ] result = genie.ask_question("What is the meaning of life?") assert result == "Answer" + +def test_parse_query_result_empty(): + resp = { + "manifest": { + "schema": { + "columns": [] + } + }, + "result": None + } + result = _parse_query_result(resp) + assert result == "EMPTY" + +def test_parse_query_result_with_data(): + resp = { + "manifest": { + "schema": { + "columns": [ + {"name": "id", "type_name": "INT"}, + {"name": "name", "type_name": "STRING"}, + {"name": "created_at", "type_name": "TIMESTAMP"}, + ] + } + }, + "result": { + "data_typed_array": [ + {"values": [{"str": "1"}, {"str": "Alice"}, {"str": "2023-10-01T00:00:00Z"}]}, + {"values": [{"str": "2"}, {"str": "Bob"}, {"str": "2023-10-02T00:00:00Z"}]}, + ] + } + } + result = _parse_query_result(resp) + expected_df = pd.DataFrame( + { + "id": [1, 2], + "name": ["Alice", "Bob"], + "created_at": [datetime(2023, 10, 1).date(), datetime(2023, 10, 2).date()], + } + ) + assert result == expected_df.to_string() -def test_parse_query_result(): +def test_parse_query_result_with_null_values(): resp = { "manifest": { "schema": { "columns": [ - {"name": "col1", "type_name": "STRING"}, - {"name": "col2", "type_name": "INT"}, + {"name": "id", "type_name": "INT"}, + {"name": "name", "type_name": "STRING"}, + {"name": "created_at", "type_name": "TIMESTAMP"}, ] } }, "result": { "data_typed_array": [ - {"values": [{"str": "value1"}, {"str": "1"}]}, - {"values": [{"str": "value2"}, {"str": "2"}]}, + {"values": [{"str": "1"}, {"str": None}, {"str": "2023-10-01T00:00:00Z"}]}, + {"values": [{"str": "2"}, {"str": "Bob"}, {"str": None}]}, ] } } - expected_df = pd.DataFrame({"col1": ["value1", "value2"], "col2": [1, 2]}) result = _parse_query_result(resp) - assert result == expected_df.to_string() \ No newline at end of file + expected_df = pd.DataFrame( + { + "id": [1, 2], + "name": [None, "Bob"], + "created_at": [datetime(2023, 10, 1).date(), None], + } + ) + assert result == expected_df.to_string() From 0fa7b7066df8795ad247f48ffbfe8167838ceaa8 Mon Sep 17 00:00:00 2001 From: Prithvi Kannan Date: Wed, 16 Oct 2024 23:26:28 -0700 Subject: [PATCH 05/10] create github workflow Signed-off-by: Prithvi Kannan --- .github/workflows/main.yml | 57 ++++++++++++++++++++++++++++++++++++++ 1 file changed, 57 insertions(+) create mode 100644 .github/workflows/main.yml diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml new file mode 100644 index 0000000..dc361fa --- /dev/null +++ b/.github/workflows/main.yml @@ -0,0 +1,57 @@ +name: tests + +on: + push: + branches: + - master + pull_request: + types: + - opened + - synchronize + - reopened + - ready_for_review + +jobs: + lint: + runs-on: ubuntu-latest + timeout-minutes: 10 + steps: + - uses: actions/checkout@v4 + - name: Set up Python + uses: actions/setup-python@v5 + with: + python-version: '3.9' + - name: Install dependencies + run: | + pip install -r requirements/lint-requirements.txt + - name: Lint Python code with ruff + run: | + ruff check . + ruff format --check . + - name: Lint YAML files with yamllint + run: yamllint . + + core_test: + runs-on: ubuntu-latest + strategy: + matrix: + python-version: ['3.9', '3.10'] + timeout-minutes: 20 + steps: + - name: Checkout code + uses: actions/checkout@v4 + - name: Set up Python + uses: actions/setup-python@v5 + with: + python-version: ${{ matrix.python-version }} + - name: Install dependencies + run: | + pip install .[dev] + - name: Run tests with pydantic v1 + run: | + pip install 'pydantic<2' + pytest tests/ + - name: Run tests with pydantic v2 + run: | + pip install 'pydantic>=2' + pytest tests/ From 845776644954cf68a8bbdd6cbf3a8539b372fa06 Mon Sep 17 00:00:00 2001 From: Prithvi Kannan Date: Tue, 22 Oct 2024 13:34:45 -0700 Subject: [PATCH 06/10] remove match Signed-off-by: Prithvi Kannan --- src/databricks_ai_bridge/genie.py | 33 +++++++++++++++++-------------- 1 file changed, 18 insertions(+), 15 deletions(-) diff --git a/src/databricks_ai_bridge/genie.py b/src/databricks_ai_bridge/genie.py index fcbd955..4972ef4 100644 --- a/src/databricks_ai_bridge/genie.py +++ b/src/databricks_ai_bridge/genie.py @@ -13,6 +13,7 @@ def _parse_query_result(resp) -> Union[str, pd.DataFrame]: output = resp["result"] if not output: return "EMPTY" + for item in resp["result"]["data_typed_array"]: row = [] for column, value in zip(columns, item["values"]): @@ -21,22 +22,24 @@ def _parse_query_result(resp) -> Union[str, pd.DataFrame]: if str_value is None: row.append(None) continue - match type_name: - case "INT" | "LONG" | "SHORT" | "BYTE": - row.append(int(str_value)) - case "FLOAT" | "DOUBLE" | "DECIMAL": - row.append(float(str_value)) - case "BOOLEAN": - row.append(str_value.lower() == "true") - case "DATE": - row.append(datetime.strptime(str_value[:10], "%Y-%m-%d").date()) - case "TIMESTAMP": - row.append(datetime.strptime(str_value[:10], "%Y-%m-%d").date()) - case "BINARY": - row.append(bytes(str_value, "utf-8")) - case _: - row.append(str_value) + + if type_name in ["INT", "LONG", "SHORT", "BYTE"]: + row.append(int(str_value)) + elif type_name in ["FLOAT", "DOUBLE", "DECIMAL"]: + row.append(float(str_value)) + elif type_name == "BOOLEAN": + row.append(str_value.lower() == "true") + elif type_name == "DATE": + row.append(datetime.strptime(str_value[:10], "%Y-%m-%d").date()) + elif type_name == "TIMESTAMP": + row.append(datetime.strptime(str_value[:10], "%Y-%m-%d").date()) + elif type_name == "BINARY": + row.append(bytes(str_value, "utf-8")) + else: + row.append(str_value) + rows.append(row) + query_result = pd.DataFrame(rows, columns=header).to_string() return query_result From 96ab57a64f80c8ab7c49317099f44945b2366f72 Mon Sep 17 00:00:00 2001 From: Prithvi Kannan Date: Tue, 22 Oct 2024 13:46:05 -0700 Subject: [PATCH 07/10] fix test Signed-off-by: Prithvi Kannan --- pyproject.toml | 2 +- tests/databricks_ai_bridge/test_genie.py | 37 ++++++++++++++++++++---- 2 files changed, 33 insertions(+), 6 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 6e76c92..775d184 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -6,7 +6,7 @@ authors = [ { name="Prithvi Kannan", email="prithvi.kannan@databricks.com" }, ] readme = "README.md" -requires-python = ">=3.9" +requires-python = ">=3.8" dependencies = [ "typing_extensions", "pydantic" diff --git a/tests/databricks_ai_bridge/test_genie.py b/tests/databricks_ai_bridge/test_genie.py index ac65d17..cff579c 100644 --- a/tests/databricks_ai_bridge/test_genie.py +++ b/tests/databricks_ai_bridge/test_genie.py @@ -1,8 +1,11 @@ from datetime import datetime +from unittest.mock import patch + +import pandas as pd import pytest -from unittest.mock import MagicMock, patch + from databricks_ai_bridge.genie import Genie, _parse_query_result -import pandas as pd + @pytest.fixture def mock_workspace_client(): @@ -45,11 +48,35 @@ def test_poll_for_result_completed(genie, mock_workspace_client): def test_poll_for_result_executing_query(genie, mock_workspace_client): mock_workspace_client.genie._api.do.side_effect = [ - {"status": "EXECUTING_QUERY", "attachments": [{"query": {"query": "SELECT *"}}]}, - {"statement_response": {"status": {"state": "SUCCEEDED"}, "result": {"data_typed_array": [], "manifest": {"schema": {"columns": []}}}}}, + { + "status": "EXECUTING_QUERY", + "attachments": [ + { + "query": { + "query": "SELECT *" + } + } + ] + }, + { + "statement_response": { + "status": { + "state": "SUCCEEDED" + }, + "manifest": { + "schema": { + "columns": [] + } + }, + "result": { + "data_typed_array": [], + } + } + } ] result = genie.poll_for_result("123", "456") - assert result == "EMPTY" + assert result == pd.DataFrame().to_string() + def test_ask_question(genie, mock_workspace_client): mock_workspace_client.genie._api.do.side_effect = [ From 1b17d84ba1fbfc7389304fd7c95cd267d6296293 Mon Sep 17 00:00:00 2001 From: Prithvi Kannan Date: Tue, 22 Oct 2024 13:50:55 -0700 Subject: [PATCH 08/10] ruff Signed-off-by: Prithvi Kannan --- src/databricks_ai_bridge/genie.py | 16 +++----- tests/databricks_ai_bridge/test_genie.py | 48 +++++++++--------------- 2 files changed, 23 insertions(+), 41 deletions(-) diff --git a/src/databricks_ai_bridge/genie.py b/src/databricks_ai_bridge/genie.py index 4972ef4..d6175ae 100644 --- a/src/databricks_ai_bridge/genie.py +++ b/src/databricks_ai_bridge/genie.py @@ -13,7 +13,7 @@ def _parse_query_result(resp) -> Union[str, pd.DataFrame]: output = resp["result"] if not output: return "EMPTY" - + for item in resp["result"]["data_typed_array"]: row = [] for column, value in zip(columns, item["values"]): @@ -37,9 +37,9 @@ def _parse_query_result(resp) -> Union[str, pd.DataFrame]: row.append(bytes(str_value, "utf-8")) else: row.append(str_value) - + rows.append(row) - + query_result = pd.DataFrame(rows, columns=header).to_string() return query_result @@ -81,15 +81,11 @@ def poll_result(): headers=self.headers, ) if resp["status"] == "EXECUTING_QUERY": - sql = next(r for r in resp["attachments"] if "query" in r)["query"][ - "query" - ] + sql = next(r for r in resp["attachments"] if "query" in r)["query"]["query"] # print(f"SQL: {sql}") return poll_query_results() elif resp["status"] == "COMPLETED": - return next(r for r in resp["attachments"] if "text" in r)["text"][ - "content" - ] + return next(r for r in resp["attachments"] if "text" in r)["text"]["content"] else: # print(f"Waiting...: {resp['status']}") time.sleep(5) @@ -116,4 +112,4 @@ def poll_query_results(): def ask_question(self, question): resp = self.start_conversation(question) # TODO (prithvi): return the query and the result - return self.poll_for_result(resp["conversation_id"], resp["message_id"]) \ No newline at end of file + return self.poll_for_result(resp["conversation_id"], resp["message_id"]) diff --git a/tests/databricks_ai_bridge/test_genie.py b/tests/databricks_ai_bridge/test_genie.py index cff579c..edd720a 100644 --- a/tests/databricks_ai_bridge/test_genie.py +++ b/tests/databricks_ai_bridge/test_genie.py @@ -13,10 +13,12 @@ def mock_workspace_client(): mock_client = MockWorkspaceClient.return_value yield mock_client + @pytest.fixture def genie(mock_workspace_client): return Genie(space_id="test_space_id") + def test_start_conversation(genie, mock_workspace_client): mock_workspace_client.genie._api.do.return_value = {"conversation_id": "123"} response = genie.start_conversation("Hello") @@ -28,6 +30,7 @@ def test_start_conversation(genie, mock_workspace_client): headers=genie.headers, ) + def test_create_message(genie, mock_workspace_client): mock_workspace_client.genie._api.do.return_value = {"message_id": "456"} response = genie.create_message("123", "Hello again") @@ -39,6 +42,7 @@ def test_create_message(genie, mock_workspace_client): headers=genie.headers, ) + def test_poll_for_result_completed(genie, mock_workspace_client): mock_workspace_client.genie._api.do.side_effect = [ {"status": "COMPLETED", "attachments": [{"text": {"content": "Result"}}]}, @@ -46,33 +50,19 @@ def test_poll_for_result_completed(genie, mock_workspace_client): result = genie.poll_for_result("123", "456") assert result == "Result" + def test_poll_for_result_executing_query(genie, mock_workspace_client): mock_workspace_client.genie._api.do.side_effect = [ - { - "status": "EXECUTING_QUERY", - "attachments": [ - { - "query": { - "query": "SELECT *" - } - } - ] - }, + {"status": "EXECUTING_QUERY", "attachments": [{"query": {"query": "SELECT *"}}]}, { "statement_response": { - "status": { - "state": "SUCCEEDED" - }, - "manifest": { - "schema": { - "columns": [] - } - }, + "status": {"state": "SUCCEEDED"}, + "manifest": {"schema": {"columns": []}}, "result": { "data_typed_array": [], - } + }, } - } + }, ] result = genie.poll_for_result("123", "456") assert result == pd.DataFrame().to_string() @@ -85,19 +75,14 @@ def test_ask_question(genie, mock_workspace_client): ] result = genie.ask_question("What is the meaning of life?") assert result == "Answer" - + + def test_parse_query_result_empty(): - resp = { - "manifest": { - "schema": { - "columns": [] - } - }, - "result": None - } + resp = {"manifest": {"schema": {"columns": []}}, "result": None} result = _parse_query_result(resp) assert result == "EMPTY" + def test_parse_query_result_with_data(): resp = { "manifest": { @@ -114,7 +99,7 @@ def test_parse_query_result_with_data(): {"values": [{"str": "1"}, {"str": "Alice"}, {"str": "2023-10-01T00:00:00Z"}]}, {"values": [{"str": "2"}, {"str": "Bob"}, {"str": "2023-10-02T00:00:00Z"}]}, ] - } + }, } result = _parse_query_result(resp) expected_df = pd.DataFrame( @@ -126,6 +111,7 @@ def test_parse_query_result_with_data(): ) assert result == expected_df.to_string() + def test_parse_query_result_with_null_values(): resp = { "manifest": { @@ -142,7 +128,7 @@ def test_parse_query_result_with_null_values(): {"values": [{"str": "1"}, {"str": None}, {"str": "2023-10-01T00:00:00Z"}]}, {"values": [{"str": "2"}, {"str": "Bob"}, {"str": None}]}, ] - } + }, } result = _parse_query_result(resp) expected_df = pd.DataFrame( From 4cadcbcd9247484500caf8cc124a46a0dd0ebe9a Mon Sep 17 00:00:00 2001 From: Prithvi Kannan Date: Tue, 22 Oct 2024 13:54:36 -0700 Subject: [PATCH 09/10] logging Signed-off-by: Prithvi Kannan --- src/databricks_ai_bridge/genie.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/src/databricks_ai_bridge/genie.py b/src/databricks_ai_bridge/genie.py index d6175ae..0797495 100644 --- a/src/databricks_ai_bridge/genie.py +++ b/src/databricks_ai_bridge/genie.py @@ -1,3 +1,4 @@ +import logging import time from datetime import datetime from typing import Union @@ -82,12 +83,12 @@ def poll_result(): ) if resp["status"] == "EXECUTING_QUERY": sql = next(r for r in resp["attachments"] if "query" in r)["query"]["query"] - # print(f"SQL: {sql}") + logging.debug(f"SQL: {sql}") return poll_query_results() elif resp["status"] == "COMPLETED": return next(r for r in resp["attachments"] if "text" in r)["text"]["content"] else: - # print(f"Waiting...: {resp['status']}") + logging.debug(f"Waiting...: {resp['status']}") time.sleep(5) def poll_query_results(): @@ -101,10 +102,10 @@ def poll_query_results(): if state == "SUCCEEDED": return _parse_query_result(resp) elif state == "RUNNING" or state == "PENDING": - # print(f"Waiting for query result...") + logging.debug("Waiting for query result...") time.sleep(5) else: - # print(f"No query result: {resp['state']}") + logging.debug(f"No query result: {resp['state']}") return None return poll_result() From 0e58f54ad7ec7812093ddbfeba56855a6e6d9c23 Mon Sep 17 00:00:00 2001 From: Prithvi Kannan Date: Tue, 22 Oct 2024 13:57:36 -0700 Subject: [PATCH 10/10] remove version file Signed-off-by: Prithvi Kannan --- src/databricks_ai_bridge/version.py | 1 - 1 file changed, 1 deletion(-) delete mode 100644 src/databricks_ai_bridge/version.py diff --git a/src/databricks_ai_bridge/version.py b/src/databricks_ai_bridge/version.py deleted file mode 100644 index 901e511..0000000 --- a/src/databricks_ai_bridge/version.py +++ /dev/null @@ -1 +0,0 @@ -VERSION = "0.0.1"