-
Notifications
You must be signed in to change notification settings - Fork 4
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add new Databricks Vector Search langchain native tool VectorSearchRetrieverTool #24
base: main
Are you sure you want to change the base?
Changes from all commits
fc34a39
1e1798e
d662b46
9d32449
60a5eb5
fe09b25
2105f55
5b5361a
66fe9a7
9166b5a
53f9924
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,98 @@ | ||
from typing import Any, Dict, List, Optional, Type

from langchain_core.documents import Document
from langchain_core.embeddings import Embeddings
from langchain_core.tools import BaseTool
from pydantic import BaseModel, Field, PrivateAttr, model_validator

from databricks_langchain.utils import IndexDetails
from databricks_langchain.vectorstores import DatabricksVectorSearch
|
||
|
||
# Argument schema for a single invocation of the retriever tool: the only
# call-time argument is the free-text query. Documented with `#` comments
# rather than a docstring so the model's generated JSON schema is unchanged.
class VectorSearchRetrieverToolInput(BaseModel):
    # Free-text search string; the description below is part of the schema.
    query: str = Field(
        description=(
            "The string used to query the index with and identify the most similar "
            "vectors and return the associated documents."
        ),
    )
|
||
|
||
class VectorSearchRetrieverTool(BaseTool):
    """
    A utility class to create a vector search-based retrieval tool for querying indexed embeddings.
    This class integrates with a Databricks Vector Search and provides a convenient interface
    for building a retriever tool for agents.
    """

    # ---- Construction-time configuration (set by the caller, not by the LLM) ----
    index_name: str = Field(
        ..., description="The name of the index to use, format: 'catalog.schema.index'."
    )
    num_results: int = Field(10, description="The number of results to return.")
    columns: Optional[List[str]] = Field(
        None, description="Columns to return when doing the search."
    )
    filters: Optional[Dict[str, Any]] = Field(None, description="Filters to apply to the search.")
    query_type: str = Field(
        "ANN", description="The type of this query. Supported values are 'ANN' and 'HYBRID'."
    )
    tool_name: Optional[str] = Field(None, description="The name of the retrieval tool.")
    tool_description: Optional[str] = Field(None, description="A description of the tool.")
    text_column: Optional[str] = Field(
        None,
        description="The name of the text column to use for the embeddings. "
        "Required for direct-access index or delta-sync index with "
        "self-managed embeddings.",
    )
    embedding: Optional[Embeddings] = Field(
        None, description="Embedding model for self-managed embeddings."
    )

    # The BaseTool class requires 'name' and 'description' fields which we will populate in validate_tool_inputs()
    name: str = Field(default="", description="The name of the tool")
    description: str = Field(default="", description="The description of the tool")
    # Schema of the call-time arguments (a single `query` string).
    args_schema: Type[BaseModel] = VectorSearchRetrieverToolInput

    # Vector store client; created once in validate_tool_inputs() and reused by _run().
    _vector_store: DatabricksVectorSearch = PrivateAttr()

    @model_validator(mode="after")
    def validate_tool_inputs(self):
        """Build the vector store and fill in the tool's ``name`` and ``description``.

        Runs after field validation. ``name`` falls back to ``index_name`` when
        ``tool_name`` is not given; ``description`` falls back to a generated
        default when ``tool_description`` is not given.

        Returns:
            The validated tool instance (``self``).
        """
        dbvs = DatabricksVectorSearch(
            index_name=self.index_name,
            embedding=self.embedding,
            text_column=self.text_column,
            columns=self.columns,
        )
        self._vector_store = dbvs

        self.name = self.tool_name or self.index_name
        # `or` short-circuits: the default description (which may call the
        # workspace API) is only computed when the caller did not supply one.
        self.description = self.tool_description or self._get_default_tool_description(
            IndexDetails(dbvs.index)
        )

        return self

    def _get_default_tool_description(self, index_details: IndexDetails) -> str:
        """Return the description used when ``tool_description`` is not provided.

        For a delta-sync index with a recorded source table, the description
        names that table and appends the table's comment when it has one.

        Args:
            index_details: Details of the index backing this tool.

        Returns:
            The generated tool description.
        """
        default_tool_description = (
            "A vector search-based retrieval tool for querying indexed embeddings."
        )

        # Direct access indexes don't have an associated source table, so we
        # just use the default tool description for them.
        if not index_details.is_delta_sync_index():
            return default_tool_description

        source_table = index_details.index_spec.get("source_table", "")
        if not source_table:
            # No source table recorded in the index spec: there is nothing to
            # look up, and an empty table name is not a valid lookup key.
            return default_tool_description

        # Imported lazily so the SDK is only loaded when the lookup is needed.
        from databricks.sdk import WorkspaceClient

        source_table_comment = WorkspaceClient().tables.get(full_name=source_table).comment

        description = (
            default_tool_description
            + f" The queried index uses the source table {source_table}"
        )
        if source_table_comment:
            description += " with the description: " + source_table_comment
        return description

    def _run(self, query: str) -> List[Document]:
        """Run a similarity search for ``query`` against the configured index.

        Args:
            query: The text to search for.

        Returns:
            The result of the vector store's ``similarity_search`` call, using
            the configured ``num_results``, ``filters`` and ``query_type``.
        """
        return self._vector_store.similarity_search(
            query, k=self.num_results, filter=self.filters, query_type=self.query_type
        )
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
QQ, does this get sent to the LLM as the parameter description? If so I wonder if it's worth including examples like the ones in https://docs.databricks.com/api/workspace/vectorsearchindexes/queryindex
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Oh nvm, this is in the init, not in the tool call
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
But seems like there is a way we can specify the description of the params for the LLM too: https://chatgpt.com/share/6764d76f-69a0-8009-8a8f-f58977753057
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
See also https://python.langchain.com/docs/how_to/custom_tools/#subclass-basetool (we can use `args_schema`)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Updated to include VectorSearchRetrieverToolInput as an args_schema