diff --git a/src/databricks_ai_bridge/genie.py b/src/databricks_ai_bridge/genie.py index 50e0a89..66ab1dc 100644 --- a/src/databricks_ai_bridge/genie.py +++ b/src/databricks_ai_bridge/genie.py @@ -5,10 +5,16 @@ import pandas as pd from databricks.sdk import WorkspaceClient +import tiktoken MAX_TOKENS_OF_DATA = 20000 # max tokens of data in markdown format MAX_ITERATIONS = 50 # max times to poll the API when polling for either result or the query results, each iteration is ~1 second, so max latency == 2 * MAX_ITERATIONS +# Define a function to count tokens +def _count_tokens(text): + encoding = tiktoken.encoding_for_model("gpt-4o") + return len(encoding.encode(text)) + def _parse_query_result(resp) -> Union[str, pd.DataFrame]: columns = resp["manifest"]["schema"]["columns"] header = [str(col["name"]) for col in columns] @@ -43,8 +49,25 @@ def _parse_query_result(resp) -> Union[str, pd.DataFrame]: rows.append(row) - query_result = pd.DataFrame(rows, columns=header).to_string() - return query_result + query_result = pd.DataFrame(rows, columns=header).to_markdown() + + # trim down from the total rows until we get under the token limit + trimmed_rows = len(rows) + tokens_used = _count_tokens(query_result) + while trimmed_rows > 0 and tokens_used > MAX_TOKENS_OF_DATA: + # convert to markdown + query_result = ( + pd.DataFrame(rows, columns=header).head(trimmed_rows).to_markdown() + ) + # keep trimming down until we get under the token limit + trimmed_rows -= 5 + # worst case, return None, which the Agent will handle and not display the query results + tokens_used = _count_tokens(query_result) + if trimmed_rows == 0: + query_result = None + tokens_used = 0 + + return query_result.strip() if query_result else query_result class Genie: