Skip to content

Commit

Permalink
Merge pull request #2 from databricks/setup-genie-core
Browse files Browse the repository at this point in the history
Create Genie API wrapper
  • Loading branch information
prithvikannan authored Oct 24, 2024
2 parents c6164db + 86a74fb commit b314f60
Show file tree
Hide file tree
Showing 2 changed files with 257 additions and 0 deletions.
116 changes: 116 additions & 0 deletions src/databricks_ai_bridge/genie.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,116 @@
import logging
import time
from datetime import datetime
from typing import Union

import pandas as pd
from databricks.sdk import WorkspaceClient


def _parse_query_result(resp) -> Union[str, pd.DataFrame]:
columns = resp["manifest"]["schema"]["columns"]
header = [str(col["name"]) for col in columns]
rows = []
output = resp["result"]
if not output:
return "EMPTY"

for item in resp["result"]["data_typed_array"]:
row = []
for column, value in zip(columns, item["values"]):
type_name = column["type_name"]
str_value = value.get("str", None)
if str_value is None:
row.append(None)
continue

if type_name in ["INT", "LONG", "SHORT", "BYTE"]:
row.append(int(str_value))
elif type_name in ["FLOAT", "DOUBLE", "DECIMAL"]:
row.append(float(str_value))
elif type_name == "BOOLEAN":
row.append(str_value.lower() == "true")
elif type_name == "DATE":
row.append(datetime.strptime(str_value[:10], "%Y-%m-%d").date())
elif type_name == "TIMESTAMP":
row.append(datetime.strptime(str_value[:10], "%Y-%m-%d").date())
elif type_name == "BINARY":
row.append(bytes(str_value, "utf-8"))
else:
row.append(str_value)

rows.append(row)

query_result = pd.DataFrame(rows, columns=header).to_string()
return query_result


class Genie:
def __init__(self, space_id):
self.space_id = space_id
workspace_client = WorkspaceClient()
self.genie = workspace_client.genie
self.headers = {
"Accept": "application/json",
"Content-Type": "application/json",
}

def start_conversation(self, content):
resp = self.genie._api.do(
"POST",
f"/api/2.0/genie/spaces/{self.space_id}/start-conversation",
body={"content": content},
headers=self.headers,
)
return resp

def create_message(self, conversation_id, content):
resp = self.genie._api.do(
"POST",
f"/api/2.0/genie/spaces/{self.space_id}/conversations/{conversation_id}/messages",
body={"content": content},
headers=self.headers,
)
return resp

def poll_for_result(self, conversation_id, message_id):
def poll_result():
while True:
resp = self.genie._api.do(
"GET",
f"/api/2.0/genie/spaces/{self.space_id}/conversations/{conversation_id}/messages/{message_id}",
headers=self.headers,
)
if resp["status"] == "EXECUTING_QUERY":
sql = next(r for r in resp["attachments"] if "query" in r)["query"]["query"]
logging.debug(f"SQL: {sql}")
return poll_query_results()
elif resp["status"] == "COMPLETED":
return next(r for r in resp["attachments"] if "text" in r)["text"]["content"]
else:
logging.debug(f"Waiting...: {resp['status']}")
time.sleep(5)

def poll_query_results():
while True:
resp = self.genie._api.do(
"GET",
f"/api/2.0/genie/spaces/{self.space_id}/conversations/{conversation_id}/messages/{message_id}/query-result",
headers=self.headers,
)["statement_response"]
state = resp["status"]["state"]
if state == "SUCCEEDED":
return _parse_query_result(resp)
elif state == "RUNNING" or state == "PENDING":
logging.debug("Waiting for query result...")
time.sleep(5)
else:
logging.debug(f"No query result: {resp['state']}")
return None

return poll_result()

def ask_question(self, question):
resp = self.start_conversation(question)
# TODO (prithvi): return the query and the result
return self.poll_for_result(resp["conversation_id"], resp["message_id"])
141 changes: 141 additions & 0 deletions tests/databricks_ai_bridge/test_genie.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,141 @@
from datetime import datetime
from unittest.mock import patch

import pandas as pd
import pytest

from databricks_ai_bridge.genie import Genie, _parse_query_result


@pytest.fixture
def mock_workspace_client():
with patch("databricks_ai_bridge.genie.WorkspaceClient") as MockWorkspaceClient:
mock_client = MockWorkspaceClient.return_value
yield mock_client


@pytest.fixture
def genie(mock_workspace_client):
return Genie(space_id="test_space_id")


def test_start_conversation(genie, mock_workspace_client):
mock_workspace_client.genie._api.do.return_value = {"conversation_id": "123"}
response = genie.start_conversation("Hello")
assert response == {"conversation_id": "123"}
mock_workspace_client.genie._api.do.assert_called_once_with(
"POST",
"/api/2.0/genie/spaces/test_space_id/start-conversation",
body={"content": "Hello"},
headers=genie.headers,
)


def test_create_message(genie, mock_workspace_client):
mock_workspace_client.genie._api.do.return_value = {"message_id": "456"}
response = genie.create_message("123", "Hello again")
assert response == {"message_id": "456"}
mock_workspace_client.genie._api.do.assert_called_once_with(
"POST",
"/api/2.0/genie/spaces/test_space_id/conversations/123/messages",
body={"content": "Hello again"},
headers=genie.headers,
)


def test_poll_for_result_completed(genie, mock_workspace_client):
mock_workspace_client.genie._api.do.side_effect = [
{"status": "COMPLETED", "attachments": [{"text": {"content": "Result"}}]},
]
result = genie.poll_for_result("123", "456")
assert result == "Result"


def test_poll_for_result_executing_query(genie, mock_workspace_client):
mock_workspace_client.genie._api.do.side_effect = [
{"status": "EXECUTING_QUERY", "attachments": [{"query": {"query": "SELECT *"}}]},
{
"statement_response": {
"status": {"state": "SUCCEEDED"},
"manifest": {"schema": {"columns": []}},
"result": {
"data_typed_array": [],
},
}
},
]
result = genie.poll_for_result("123", "456")
assert result == pd.DataFrame().to_string()


def test_ask_question(genie, mock_workspace_client):
mock_workspace_client.genie._api.do.side_effect = [
{"conversation_id": "123", "message_id": "456"},
{"status": "COMPLETED", "attachments": [{"text": {"content": "Answer"}}]},
]
result = genie.ask_question("What is the meaning of life?")
assert result == "Answer"


def test_parse_query_result_empty():
resp = {"manifest": {"schema": {"columns": []}}, "result": None}
result = _parse_query_result(resp)
assert result == "EMPTY"


def test_parse_query_result_with_data():
resp = {
"manifest": {
"schema": {
"columns": [
{"name": "id", "type_name": "INT"},
{"name": "name", "type_name": "STRING"},
{"name": "created_at", "type_name": "TIMESTAMP"},
]
}
},
"result": {
"data_typed_array": [
{"values": [{"str": "1"}, {"str": "Alice"}, {"str": "2023-10-01T00:00:00Z"}]},
{"values": [{"str": "2"}, {"str": "Bob"}, {"str": "2023-10-02T00:00:00Z"}]},
]
},
}
result = _parse_query_result(resp)
expected_df = pd.DataFrame(
{
"id": [1, 2],
"name": ["Alice", "Bob"],
"created_at": [datetime(2023, 10, 1).date(), datetime(2023, 10, 2).date()],
}
)
assert result == expected_df.to_string()


def test_parse_query_result_with_null_values():
resp = {
"manifest": {
"schema": {
"columns": [
{"name": "id", "type_name": "INT"},
{"name": "name", "type_name": "STRING"},
{"name": "created_at", "type_name": "TIMESTAMP"},
]
}
},
"result": {
"data_typed_array": [
{"values": [{"str": "1"}, {"str": None}, {"str": "2023-10-01T00:00:00Z"}]},
{"values": [{"str": "2"}, {"str": "Bob"}, {"str": None}]},
]
},
}
result = _parse_query_result(resp)
expected_df = pd.DataFrame(
{
"id": [1, 2],
"name": [None, "Bob"],
"created_at": [datetime(2023, 10, 1).date(), None],
}
)
assert result == expected_df.to_string()

0 comments on commit b314f60

Please sign in to comment.