From e9dfff751fa95dccec3e8a7e53d0228bba45b341 Mon Sep 17 00:00:00 2001 From: Prithvi Kannan Date: Tue, 22 Oct 2024 14:21:41 -0700 Subject: [PATCH 1/6] Create Langchain Genie Signed-off-by: Prithvi Kannan --- integrations/langchain/pyproject.toml | 68 ++++++++++++++++++ .../src/databricks_langchain/genie.py | 43 +++++++++++ integrations/langchain/tests/test_genie.py | 72 +++++++++++++++++++ 3 files changed, 183 insertions(+) create mode 100644 integrations/langchain/pyproject.toml create mode 100644 integrations/langchain/src/databricks_langchain/genie.py create mode 100644 integrations/langchain/tests/test_genie.py diff --git a/integrations/langchain/pyproject.toml b/integrations/langchain/pyproject.toml new file mode 100644 index 0000000..5d06e6f --- /dev/null +++ b/integrations/langchain/pyproject.toml @@ -0,0 +1,68 @@ +[project] +name = "databricks-langchain" +version = "0.0.1" +description = "Support for Datarbricks AI support in LangChain" +authors = [ + { name="Prithvi Kannan", email="prithvi.kannan@databricks.com" }, +] +readme = "README.md" +license = { text="Apache-2.0" } +requires-python = ">=3.8" +dependencies = [ + "langchain>=0.2.0", + "langchain-community>=0.2.0", + "databricks-ai-bridge", +] + +[project.optional-dependencies] +dev = [ + "pytest", + "typing_extensions", + "databricks-sdk>=0.34.0", + "ruff==0.6.4", + "langgraph", +] + +[build-system] +requires = ["hatchling"] +build-backend = "hatchling.build" + +[tool.hatch.build] +include = [ + "src/databricks_langchain/*" +] + +[tool.hatch.build.targets.wheel] +packages = ["src/databricks_langchain"] + +[tool.ruff] +line-length = 100 +target-version = "py39" + +[tool.ruff.lint] +select = [ + # isort + "I", + # bugbear rules + "B", + # remove unused imports + "F401", + # bare except statements + "E722", + # print statements + "T201", + "T203", + # misuse of typing.TYPE_CHECKING + "TCH004", + # import rules + "TID251", + # undefined-local-with-import-star + "F403", +] + +[tool.ruff.format] +docstring-code-format = true +docstring-code-line-length = 88 + +[tool.ruff.lint.pydocstyle] +convention = "google" \ No newline at end of file diff --git a/integrations/langchain/src/databricks_langchain/genie.py b/integrations/langchain/src/databricks_langchain/genie.py new file mode 100644 index 0000000..c832cd4 --- /dev/null +++ b/integrations/langchain/src/databricks_langchain/genie.py @@ -0,0 +1,43 @@ +from databricks_ai_bridge.genie import Genie + +def _concat_messages_array(messages): + concatenated_message = "\n".join( + [ + f"{message.get('role', message.get('name', 'unknown'))}: {message.get('content', '')}" + if isinstance(message, dict) + else f"{getattr(message, 'role', getattr(message, 'name', 'unknown'))}: {getattr(message, 'content', '')}" + for message in messages + ] + ) + return concatenated_message + + +def _query_genie_as_agent(input, genie_space_id, genie_agent_name): + from langchain_core.messages import AIMessage + genie = Genie(genie_space_id) + + message = f"I will provide you a chat history, where your name is {genie_agent_name}. Please help with the described information in the chat history.\n" + + # Concatenate messages to form the chat history + message += _concat_messages_array(input.get("messages")) + + # Send the message and wait for a response + genie_response = genie.ask_question(message) + + if genie_response: + return {"messages": [AIMessage(content=genie_response)]} + else: + return {"messages": [AIMessage(content="")]} + + +def create_genie_agent(genie_space_id, genie_agent_name="Genie"): + """Create a genie agent that can be used to query the API""" + from functools import partial + + from langchain_core.runnables import RunnableLambda + + # Create a partial function with the genie_space_id pre-filled + partial_genie_agent = partial(_query_genie_as_agent, genie_space_id=genie_space_id, genie_agent_name=genie_agent_name) + + # Use the partial function in the RunnableLambda + return RunnableLambda(partial_genie_agent) \ No newline at end of file diff --git a/integrations/langchain/tests/test_genie.py b/integrations/langchain/tests/test_genie.py new file mode 100644 index 0000000..8fe3f66 --- /dev/null +++ b/integrations/langchain/tests/test_genie.py @@ -0,0 +1,72 @@ +import pytest +from unittest.mock import patch, MagicMock +from langchain_core.messages import AIMessage +from my_module import _concat_messages_array, _query_genie_as_agent, create_genie_agent + +def test_concat_messages_array(): + # Test a simple case with multiple messages + messages = [ + {"role": "user", "content": "What is the weather?"}, + {"role": "assistant", "content": "It is sunny."} + ] + result = _concat_messages_array(messages) + expected = "user: What is the weather?\nassistant: It is sunny." + assert result == expected + + # Test case with missing content + messages = [ + {"role": "user"}, + {"role": "assistant", "content": "I don't know."} + ] + result = _concat_messages_array(messages) + expected = "user: \nassistant: I don't know." + assert result == expected + + # Test case with non-dict message objects + class Message: + def __init__(self, role, content): + self.role = role + self.content = content + + messages = [ + Message("user", "Tell me a joke."), + Message("assistant", "Why did the chicken cross the road?") + ] + result = _concat_messages_array(messages) + expected = "user: Tell me a joke.\nassistant: Why did the chicken cross the road?" + assert result == expected + + +@patch('databricks_ai_bridge.genie.Genie') +def test_query_genie_as_agent(MockGenie): + # Mock the Genie class and its response + mock_genie = MockGenie.return_value + mock_genie.ask_question.return_value = "It is sunny." + + input_data = {"messages": [{"role": "user", "content": "What is the weather?"}]} + result = _query_genie_as_agent(input_data, "space-id", "Genie") + + expected_message = { + "messages": [AIMessage(content="It is sunny.")] + } + assert result == expected_message + + # Test the case when genie_response is empty + mock_genie.ask_question.return_value = None + result = _query_genie_as_agent(input_data, "space-id", "Genie") + + expected_message = { + "messages": [AIMessage(content="")] + } + assert result == expected_message + + +@patch('langchain_core.runnables.RunnableLambda') +def test_create_genie_agent(MockRunnableLambda): + mock_runnable = MockRunnableLambda.return_value + + agent = create_genie_agent("space-id", "Genie") + assert agent == mock_runnable + + # Check that the partial function is created with the correct arguments + MockRunnableLambda.assert_called() \ No newline at end of file From ab8b45026bc33711306fbc8a0ae797cb8303a803 Mon Sep 17 00:00:00 2001 From: Prithvi Kannan Date: Tue, 22 Oct 2024 14:22:58 -0700 Subject: [PATCH 2/6] ruff Signed-off-by: Prithvi Kannan --- .../src/databricks_langchain/genie.py | 8 ++++-- integrations/langchain/tests/test_genie.py | 28 ++++++++----------- 2 files changed, 17 insertions(+), 19 deletions(-) diff --git a/integrations/langchain/src/databricks_langchain/genie.py b/integrations/langchain/src/databricks_langchain/genie.py index c832cd4..0528bfe 100644 --- a/integrations/langchain/src/databricks_langchain/genie.py +++ b/integrations/langchain/src/databricks_langchain/genie.py @@ -1,5 +1,6 @@ from databricks_ai_bridge.genie import Genie + def _concat_messages_array(messages): concatenated_message = "\n".join( [ @@ -14,6 +15,7 @@ def _concat_messages_array(messages): def _query_genie_as_agent(input, genie_space_id, genie_agent_name): from langchain_core.messages import AIMessage + genie = Genie(genie_space_id) message = f"I will provide you a chat history, where your name is {genie_agent_name}. Please help with the described information in the chat history.\n" @@ -37,7 +39,9 @@ def create_genie_agent(genie_space_id, genie_agent_name="Genie"): from langchain_core.runnables import RunnableLambda # Create a partial function with the genie_space_id pre-filled - partial_genie_agent = partial(_query_genie_as_agent, genie_space_id=genie_space_id, genie_agent_name=genie_agent_name) + partial_genie_agent = partial( + _query_genie_as_agent, genie_space_id=genie_space_id, genie_agent_name=genie_agent_name + ) # Use the partial function in the RunnableLambda - return RunnableLambda(partial_genie_agent) \ No newline at end of file + return RunnableLambda(partial_genie_agent) diff --git a/integrations/langchain/tests/test_genie.py b/integrations/langchain/tests/test_genie.py index 8fe3f66..9c4fabb 100644 --- a/integrations/langchain/tests/test_genie.py +++ b/integrations/langchain/tests/test_genie.py @@ -1,23 +1,21 @@ -import pytest -from unittest.mock import patch, MagicMock +from unittest.mock import patch + from langchain_core.messages import AIMessage from my_module import _concat_messages_array, _query_genie_as_agent, create_genie_agent + def test_concat_messages_array(): # Test a simple case with multiple messages messages = [ {"role": "user", "content": "What is the weather?"}, - {"role": "assistant", "content": "It is sunny."} + {"role": "assistant", "content": "It is sunny."}, ] result = _concat_messages_array(messages) expected = "user: What is the weather?\nassistant: It is sunny." assert result == expected # Test case with missing content - messages = [ - {"role": "user"}, - {"role": "assistant", "content": "I don't know."} - ] + messages = [{"role": "user"}, {"role": "assistant", "content": "I don't know."}] result = _concat_messages_array(messages) expected = "user: \nassistant: I don't know." assert result == expected @@ -30,14 +28,14 @@ def __init__(self, role, content): messages = [ Message("user", "Tell me a joke."), - Message("assistant", "Why did the chicken cross the road?") + Message("assistant", "Why did the chicken cross the road?"), ] result = _concat_messages_array(messages) expected = "user: Tell me a joke.\nassistant: Why did the chicken cross the road?" assert result == expected -@patch('databricks_ai_bridge.genie.Genie') +@patch("databricks_ai_bridge.genie.Genie") def test_query_genie_as_agent(MockGenie): # Mock the Genie class and its response mock_genie = MockGenie.return_value @@ -46,22 +44,18 @@ def test_query_genie_as_agent(MockGenie): input_data = {"messages": [{"role": "user", "content": "What is the weather?"}]} result = _query_genie_as_agent(input_data, "space-id", "Genie") - expected_message = { - "messages": [AIMessage(content="It is sunny.")] - } + expected_message = {"messages": [AIMessage(content="It is sunny.")]} assert result == expected_message # Test the case when genie_response is empty mock_genie.ask_question.return_value = None result = _query_genie_as_agent(input_data, "space-id", "Genie") - expected_message = { - "messages": [AIMessage(content="")] - } + expected_message = {"messages": [AIMessage(content="")]} assert result == expected_message -@patch('langchain_core.runnables.RunnableLambda') +@patch("langchain_core.runnables.RunnableLambda") def test_create_genie_agent(MockRunnableLambda): mock_runnable = MockRunnableLambda.return_value @@ -69,4 +63,4 @@ def test_create_genie_agent(MockRunnableLambda): assert agent == mock_runnable # Check that the partial function is created with the correct arguments - MockRunnableLambda.assert_called() \ No newline at end of file + MockRunnableLambda.assert_called() From aa4f3d115c25b2cae255e2b14322fc1c48a0d7a8 Mon Sep 17 00:00:00 2001 From: Prithvi Kannan Date: Thu, 24 Oct 2024 09:53:22 -0700 Subject: [PATCH 3/6] update langchain test Signed-off-by: Prithvi Kannan --- .github/workflows/main.yml | 21 +++++++++++++++++++++ integrations/langchain/pyproject.toml | 1 - integrations/langchain/tests/test_genie.py | 4 ++-- 3 files changed, 23 insertions(+), 3 deletions(-) diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml index 91ebd2e..94a3681 100644 --- a/.github/workflows/main.yml +++ b/.github/workflows/main.yml @@ -50,3 +50,24 @@ jobs: - name: Run tests run: | pytest tests/ + + langchain_test: + runs-on: ubuntu-latest + strategy: + matrix: + python-version: ['3.8', '3.9', '3.10'] + timeout-minutes: 20 + steps: + - name: Checkout code + uses: actions/checkout@v4 + - name: Set up Python + uses: actions/setup-python@v5 + with: + python-version: ${{ matrix.python-version }} + - name: Install dependencies + run: | + pip install . + pip install integrations/langchain[dev] + - name: Run tests + run: | + pytest integrations/langchain/tests diff --git a/integrations/langchain/pyproject.toml b/integrations/langchain/pyproject.toml index 5d06e6f..40527ed 100644 --- a/integrations/langchain/pyproject.toml +++ b/integrations/langchain/pyproject.toml @@ -20,7 +20,6 @@ dev = [ "typing_extensions", "databricks-sdk>=0.34.0", "ruff==0.6.4", - "langgraph", ] [build-system] diff --git a/integrations/langchain/tests/test_genie.py b/integrations/langchain/tests/test_genie.py index 9c4fabb..805be03 100644 --- a/integrations/langchain/tests/test_genie.py +++ b/integrations/langchain/tests/test_genie.py @@ -1,7 +1,7 @@ from unittest.mock import patch from langchain_core.messages import AIMessage -from my_module import _concat_messages_array, _query_genie_as_agent, create_genie_agent +from databricks_langchain.genie import _concat_messages_array, _query_genie_as_agent, create_genie_agent def test_concat_messages_array(): @@ -35,7 +35,7 @@ def __init__(self, role, content): assert result == expected -@patch("databricks_ai_bridge.genie.Genie") +@patch("databricks_langchain.genie.Genie") def test_query_genie_as_agent(MockGenie): # Mock the Genie class and its response mock_genie = MockGenie.return_value From 2879071e4db3e7ad2dc4c42540f176a238779b77 Mon Sep 17 00:00:00 2001 From: Prithvi Kannan Date: Thu, 24 Oct 2024 09:55:57 -0700 Subject: [PATCH 4/6] pyproject Signed-off-by: Prithvi Kannan --- CONTRIBUTING.md | 2 +- pyproject.toml | 17 +++-------------- 2 files changed, 4 insertions(+), 15 deletions(-) diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 7d00e19..3baf1f5 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -5,6 +5,6 @@ Create a conda environement and install dev requirements ``` conda create --name databricks-ai-dev-env python=3.10 conda activate databricks-ai-dev-env -pip install -e ".[databricks-dev]" +pip install -e ".[dev]" pip install -r requirements/lint-requirements.txt ``` diff --git a/pyproject.toml b/pyproject.toml index 775d184..c1df40d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -9,7 +9,9 @@ readme = "README.md" requires-python = ">=3.8" dependencies = [ "typing_extensions", - "pydantic" + "pydantic", + "databricks-sdk>=0.34.0", + "pandas", ] [project.license] @@ -28,22 +30,9 @@ include = [ packages = ["src/databricks_ai_bridge"] [project.optional-dependencies] -databricks = [ - "databricks-sdk>=0.34.0", - "pandas", -] -databricks-dev = [ - "hatch", - "pytest", - "databricks-sdk>=0.34.0", - "pandas", - "ruff==0.6.4", -] dev = [ "hatch", "pytest", - "databricks-sdk>=0.34.0", - "pandas", "ruff==0.6.4", ] From 580c92860ac9995af6a88980ad62e14411a13647 Mon Sep 17 00:00:00 2001 From: Prithvi Kannan Date: Thu, 24 Oct 2024 09:56:25 -0700 Subject: [PATCH 5/6] ruff Signed-off-by: Prithvi Kannan --- integrations/langchain/tests/test_genie.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/integrations/langchain/tests/test_genie.py b/integrations/langchain/tests/test_genie.py index 805be03..9e41bd3 100644 --- a/integrations/langchain/tests/test_genie.py +++ b/integrations/langchain/tests/test_genie.py @@ -1,7 +1,12 @@ from unittest.mock import patch from langchain_core.messages import AIMessage -from databricks_langchain.genie import _concat_messages_array, _query_genie_as_agent, create_genie_agent + +from databricks_langchain.genie import ( + _concat_messages_array, + _query_genie_as_agent, + create_genie_agent, +) def test_concat_messages_array(): From 776d8c1f0fe87a7270f6cc5e46488d93ecff74ce Mon Sep 17 00:00:00 2001 From: Prithvi Kannan Date: Thu, 24 Oct 2024 10:06:24 -0700 Subject: [PATCH 6/6] rename Signed-off-by: Prithvi Kannan --- integrations/langchain/src/databricks_langchain/genie.py | 2 +- integrations/langchain/tests/test_genie.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/integrations/langchain/src/databricks_langchain/genie.py b/integrations/langchain/src/databricks_langchain/genie.py index 0528bfe..153c2df 100644 --- a/integrations/langchain/src/databricks_langchain/genie.py +++ b/integrations/langchain/src/databricks_langchain/genie.py @@ -32,7 +32,7 @@ def _query_genie_as_agent(input, genie_space_id, genie_agent_name): return {"messages": [AIMessage(content="")]} -def create_genie_agent(genie_space_id, genie_agent_name="Genie"): +def GenieAgent(genie_space_id, genie_agent_name="Genie", description=""): """Create a genie agent that can be used to query the API""" from functools import partial diff --git a/integrations/langchain/tests/test_genie.py b/integrations/langchain/tests/test_genie.py index 9e41bd3..70c6c28 100644 --- a/integrations/langchain/tests/test_genie.py +++ b/integrations/langchain/tests/test_genie.py @@ -3,9 +3,9 @@ from langchain_core.messages import AIMessage from databricks_langchain.genie import ( + GenieAgent, _concat_messages_array, _query_genie_as_agent, - create_genie_agent, ) @@ -64,7 +64,7 @@ def test_query_genie_as_agent(MockGenie): def test_create_genie_agent(MockRunnableLambda): mock_runnable = MockRunnableLambda.return_value - agent = create_genie_agent("space-id", "Genie") + agent = GenieAgent("space-id", "Genie") assert agent == mock_runnable # Check that the partial function is created with the correct arguments