From 47761a18a1c17c7ab2110f128ed03971cf6d6c51 Mon Sep 17 00:00:00 2001 From: Prithvi Kannan Date: Tue, 29 Oct 2024 21:58:09 -0700 Subject: [PATCH] ruff Signed-off-by: Prithvi Kannan --- tests/databricks_ai_bridge/test_genie.py | 40 ++++++++++++++---------- 1 file changed, 24 insertions(+), 16 deletions(-) diff --git a/tests/databricks_ai_bridge/test_genie.py b/tests/databricks_ai_bridge/test_genie.py index 3957a32..248ebba 100644 --- a/tests/databricks_ai_bridge/test_genie.py +++ b/tests/databricks_ai_bridge/test_genie.py @@ -4,7 +4,7 @@ import pandas as pd import pytest -from databricks_ai_bridge.genie import Genie, _parse_query_result, _count_tokens +from databricks_ai_bridge.genie import Genie, _count_tokens, _parse_query_result @pytest.fixture @@ -67,6 +67,7 @@ def test_poll_for_result_executing_query(genie, mock_workspace_client): result = genie.poll_for_result("123", "456") assert result == pd.DataFrame().to_markdown() + def test_poll_for_result_failed(genie, mock_workspace_client): mock_workspace_client.genie._api.do.side_effect = [ {"status": "FAILED"}, @@ -77,8 +78,10 @@ def test_poll_for_result_failed(genie, mock_workspace_client): def test_poll_for_result_max_iterations(genie, mock_workspace_client): # patch MAX_ITERATIONS to 2 for this test and sleep to avoid delays - with patch("databricks_ai_bridge.genie.MAX_ITERATIONS", 2), \ - patch("time.sleep", return_value=None): + with ( + patch("databricks_ai_bridge.genie.MAX_ITERATIONS", 2), + patch("time.sleep", return_value=None), + ): mock_workspace_client.genie._api.do.side_effect = [ {"status": "EXECUTING_QUERY", "attachments": [{"query": {"query": "SELECT *"}}]}, { @@ -95,11 +98,12 @@ def test_poll_for_result_max_iterations(genie, mock_workspace_client): "statement_response": { "status": {"state": "RUNNING"}, } - } + }, ] result = genie.poll_for_result("123", "456") assert result is None + def test_ask_question(genie, mock_workspace_client): mock_workspace_client.genie._api.do.side_effect = [ {"conversation_id": "123", "message_id": "456"}, @@ -172,6 +176,7 @@ def test_parse_query_result_with_null_values(): ) assert result == expected_df.to_markdown() + def test_parse_query_result_trims_large_data(): # patch MAX_TOKENS_OF_DATA to 100 for this test with patch("databricks_ai_bridge.genie.MAX_TOKENS_OF_DATA", 100): @@ -201,15 +206,18 @@ def test_parse_query_result_trims_large_data(): }, } result = _parse_query_result(resp) - assert result == pd.DataFrame( - { - "id": [1, 2, 3], - "name": ["Alice", "Bob", "Charlie"], - "created_at": [ - datetime(2023, 10, 1).date(), - datetime(2023, 10, 2).date(), - datetime(2023, 10, 3).date(), - ] - } - ).to_markdown() - assert _count_tokens(result) <= 100 \ No newline at end of file + assert ( + result + == pd.DataFrame( + { + "id": [1, 2, 3], + "name": ["Alice", "Bob", "Charlie"], + "created_at": [ + datetime(2023, 10, 1).date(), + datetime(2023, 10, 2).date(), + datetime(2023, 10, 3).date(), + ], + } + ).to_markdown() + ) + assert _count_tokens(result) <= 100