diff --git a/.env.template b/.env.template index ba5a4bb51..13f9e574c 100644 --- a/.env.template +++ b/.env.template @@ -96,6 +96,13 @@ KNOWLEDGE_SEARCH_REWRITE=False # proxy_openai_proxy_api_key={your-openai-sk} # proxy_openai_proxy_backend=text-embedding-ada-002 +## Common HTTP embedding model +# EMBEDDING_MODEL=proxy_http_openapi +# proxy_http_openapi_proxy_server_url=http://localhost:8100/api/v1/embeddings +# proxy_http_openapi_proxy_api_key={your-api-key} +# proxy_http_openapi_proxy_backend=text2vec + + #*******************************************************************# #** DB-GPT METADATA DATABASE SETTINGS **# diff --git a/dbgpt/configs/model_config.py b/dbgpt/configs/model_config.py index c4d70ec4c..d6d60b8b2 100644 --- a/dbgpt/configs/model_config.py +++ b/dbgpt/configs/model_config.py @@ -172,6 +172,8 @@ def get_device() -> str: "sentence-transforms": os.path.join(MODEL_PATH, "all-MiniLM-L6-v2"), "proxy_openai": "proxy_openai", "proxy_azure": "proxy_azure", + # Common HTTP embedding model + "proxy_http_openapi": "proxy_http_openapi", } diff --git a/dbgpt/core/interface/operators/message_operator.py b/dbgpt/core/interface/operators/message_operator.py index 4018a83bb..a778c41bd 100644 --- a/dbgpt/core/interface/operators/message_operator.py +++ b/dbgpt/core/interface/operators/message_operator.py @@ -370,9 +370,13 @@ def __init__( ): """Create a new BufferedConversationMapperOperator.""" # Validate the input parameters - if keep_start_rounds is not None and keep_start_rounds < 0: + if keep_start_rounds is None: + keep_start_rounds = 0 + if keep_end_rounds is None: + keep_end_rounds = 0 + if keep_start_rounds < 0: raise ValueError("keep_start_rounds must be non-negative") - if keep_end_rounds is not None and keep_end_rounds < 0: + if keep_end_rounds < 0: raise ValueError("keep_end_rounds must be non-negative") self._keep_start_rounds = keep_start_rounds @@ -420,7 +424,7 @@ def _filter_round_messages( ... ], ... 
] - # Test keeping only the first 2 rounds + >>> # Test keeping only the first 2 rounds >>> operator = BufferedConversationMapperOperator(keep_start_rounds=2) >>> assert operator._filter_round_messages(messages) == [ ... [ @@ -433,7 +437,7 @@ def _filter_round_messages( ... ], ... ] - # Test keeping only the last 2 rounds + >>> # Test keeping only the last 2 rounds >>> operator = BufferedConversationMapperOperator(keep_end_rounds=2) >>> assert operator._filter_round_messages(messages) == [ ... [ @@ -446,7 +450,7 @@ def _filter_round_messages( ... ], ... ] - # Test keeping the first 2 and last 1 rounds + >>> # Test keeping the first 2 and last 1 rounds >>> operator = BufferedConversationMapperOperator( ... keep_start_rounds=2, keep_end_rounds=1 ... ) @@ -465,24 +469,11 @@ def _filter_round_messages( ... ], ... ] - # Test without specifying start or end rounds (keep all rounds) + >>> # Test without specifying start or end rounds (keep 0 rounds) >>> operator = BufferedConversationMapperOperator() - >>> assert operator._filter_round_messages(messages) == [ - ... [ - ... HumanMessage(content="Hi", round_index=1), - ... AIMessage(content="Hello!", round_index=1), - ... ], - ... [ - ... HumanMessage(content="How are you?", round_index=2), - ... AIMessage(content="I'm good, thanks!", round_index=2), - ... ], - ... [ - ... HumanMessage(content="What's new today?", round_index=3), - ... AIMessage(content="Lots of things!", round_index=3), - ... ], - ... ] + >>> assert operator._filter_round_messages(messages) == [] - # Test end rounds is zero + >>> # Test end rounds is zero >>> operator = BufferedConversationMapperOperator( ... keep_start_rounds=1, keep_end_rounds=0 ... 
) @@ -503,12 +494,7 @@ def _filter_round_messages( """ total_rounds = len(messages_by_round) - if ( - self._keep_start_rounds is not None - and self._keep_end_rounds is not None - and self._keep_start_rounds > 0 - and self._keep_end_rounds > 0 - ): + if self._keep_start_rounds > 0 and self._keep_end_rounds > 0: if self._keep_start_rounds + self._keep_end_rounds > total_rounds: # Avoid overlapping when the sum of start and end rounds exceeds total # rounds @@ -517,12 +503,12 @@ def _filter_round_messages( messages_by_round[: self._keep_start_rounds] + messages_by_round[-self._keep_end_rounds :] ) - elif self._keep_start_rounds is not None: + elif self._keep_start_rounds: return messages_by_round[: self._keep_start_rounds] - elif self._keep_end_rounds is not None: + elif self._keep_end_rounds: return messages_by_round[-self._keep_end_rounds :] else: - return messages_by_round + return [] EvictionPolicyType = Callable[[List[List[BaseMessage]]], List[List[BaseMessage]]] diff --git a/dbgpt/core/interface/operators/tests/__init__.py b/dbgpt/core/interface/operators/tests/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/dbgpt/core/interface/operators/tests/test_message_operator.py b/dbgpt/core/interface/operators/tests/test_message_operator.py new file mode 100644 index 000000000..5e146a0f8 --- /dev/null +++ b/dbgpt/core/interface/operators/tests/test_message_operator.py @@ -0,0 +1,155 @@ +from typing import List + +import pytest + +from dbgpt.core.interface.message import AIMessage, BaseMessage, HumanMessage +from dbgpt.core.operators import BufferedConversationMapperOperator + + +@pytest.fixture +def messages() -> List[BaseMessage]: + return [ + HumanMessage(content="Hi", round_index=1), + AIMessage(content="Hello!", round_index=1), + HumanMessage(content="How are you?", round_index=2), + AIMessage(content="I'm good, thanks!", round_index=2), + HumanMessage(content="What's new today?", round_index=3), + AIMessage(content="Lots of things!", 
round_index=3), + ] + + +@pytest.mark.asyncio +async def test_buffered_conversation_keep_start_rounds(messages: List[BaseMessage]): + # Test keep_start_rounds + operator = BufferedConversationMapperOperator( + keep_start_rounds=2, + keep_end_rounds=None, + ) + assert await operator.map_messages(messages) == [ + HumanMessage(content="Hi", round_index=1), + AIMessage(content="Hello!", round_index=1), + HumanMessage(content="How are you?", round_index=2), + AIMessage(content="I'm good, thanks!", round_index=2), + ] + # Test keep start 0 rounds + operator = BufferedConversationMapperOperator( + keep_start_rounds=0, + keep_end_rounds=None, + ) + assert await operator.map_messages(messages) == [] + + # Test keep start 100 rounds + operator = BufferedConversationMapperOperator( + keep_start_rounds=100, + keep_end_rounds=None, + ) + assert await operator.map_messages(messages) == messages + + # Test keep start -1 rounds + with pytest.raises(ValueError): + operator = BufferedConversationMapperOperator( + keep_start_rounds=-1, + keep_end_rounds=None, + ) + await operator.map_messages(messages) + + +@pytest.mark.asyncio +async def test_buffered_conversation_keep_end_rounds(messages: List[BaseMessage]): + # Test keep_end_rounds + operator = BufferedConversationMapperOperator( + keep_start_rounds=None, + keep_end_rounds=2, + ) + assert await operator.map_messages(messages) == [ + HumanMessage(content="How are you?", round_index=2), + AIMessage(content="I'm good, thanks!", round_index=2), + HumanMessage(content="What's new today?", round_index=3), + AIMessage(content="Lots of things!", round_index=3), + ] + # Test keep end 0 rounds + operator = BufferedConversationMapperOperator( + keep_start_rounds=0, + keep_end_rounds=0, + ) + assert await operator.map_messages(messages) == [] + + # Test keep end 100 rounds + operator = BufferedConversationMapperOperator( + keep_start_rounds=None, + keep_end_rounds=100, + ) + assert await operator.map_messages(messages) == messages + + # Test 
keep end -1 rounds + with pytest.raises(ValueError): + operator = BufferedConversationMapperOperator( + keep_start_rounds=None, + keep_end_rounds=-1, + ) + await operator.map_messages(messages) + + +@pytest.mark.asyncio +async def test_buffered_conversation_keep_start_end_rounds(messages: List[BaseMessage]): + # Test keep_start_rounds and keep_end_rounds + operator = BufferedConversationMapperOperator( + keep_start_rounds=1, + keep_end_rounds=1, + ) + assert await operator.map_messages(messages) == [ + HumanMessage(content="Hi", round_index=1), + AIMessage(content="Hello!", round_index=1), + HumanMessage(content="What's new today?", round_index=3), + AIMessage(content="Lots of things!", round_index=3), + ] + # Test keep start 0 rounds and keep end 0 rounds + operator = BufferedConversationMapperOperator( + keep_start_rounds=0, + keep_end_rounds=0, + ) + assert await operator.map_messages(messages) == [] + + # Test keep start 0 rounds and keep end 1 round + operator = BufferedConversationMapperOperator( + keep_start_rounds=0, + keep_end_rounds=1, + ) + assert await operator.map_messages(messages) == [ + HumanMessage(content="What's new today?", round_index=3), + AIMessage(content="Lots of things!", round_index=3), + ] + + # Test keep start 2 rounds and keep end 0 rounds + operator = BufferedConversationMapperOperator( + keep_start_rounds=2, + keep_end_rounds=0, + ) + assert await operator.map_messages(messages) == [ + HumanMessage(content="Hi", round_index=1), + AIMessage(content="Hello!", round_index=1), + HumanMessage(content="How are you?", round_index=2), + AIMessage(content="I'm good, thanks!", round_index=2), + ] + + # Test keep start 100 rounds and keep end 100 rounds + operator = BufferedConversationMapperOperator( + keep_start_rounds=100, + keep_end_rounds=100, + ) + assert await operator.map_messages(messages) == messages + + # Test keep start 2 rounds and keep end 2 rounds + operator = BufferedConversationMapperOperator( + keep_start_rounds=2, + 
keep_end_rounds=2, + ) + assert await operator.map_messages(messages) == messages + + # Test keep start -1 rounds and keep end -1 rounds + with pytest.raises(ValueError): + operator = BufferedConversationMapperOperator( + keep_start_rounds=-1, + keep_end_rounds=-1, + ) + await operator.map_messages(messages) diff --git a/dbgpt/model/cluster/embedding/loader.py b/dbgpt/model/cluster/embedding/loader.py index b4fba70b9..2bb1611b5 100644 --- a/dbgpt/model/cluster/embedding/loader.py +++ b/dbgpt/model/cluster/embedding/loader.py @@ -1,14 +1,16 @@ from __future__ import annotations -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Union, cast -from dbgpt.model.parameter import BaseEmbeddingModelParameters +from dbgpt.model.parameter import BaseEmbeddingModelParameters, ProxyEmbeddingParameters from dbgpt.util.parameter_utils import _get_dict_from_obj from dbgpt.util.system_utils import get_system_info from dbgpt.util.tracer import SpanType, SpanTypeRunName, root_tracer if TYPE_CHECKING: - from langchain.embeddings.base import Embeddings + from langchain.embeddings.base import Embeddings as LangChainEmbeddings + + from dbgpt.rag.embedding import Embeddings class EmbeddingLoader: @@ -17,7 +19,7 @@ def __init__(self) -> None: def load( self, model_name: str, param: BaseEmbeddingModelParameters - ) -> "Embeddings": + ) -> "Union[LangChainEmbeddings, Embeddings]": metadata = { "model_name": model_name, "run_service": SpanTypeRunName.EMBEDDING_MODEL.value, @@ -32,6 +34,18 @@ def load( from langchain.embeddings import OpenAIEmbeddings return OpenAIEmbeddings(**param.build_kwargs()) + elif model_name in ["proxy_http_openapi"]: + from dbgpt.rag.embedding import OpenAPIEmbeddings + + proxy_param = cast(ProxyEmbeddingParameters, param) + openapi_param = {} + if proxy_param.proxy_server_url: + openapi_param["api_url"] = proxy_param.proxy_server_url + if proxy_param.proxy_api_key: + openapi_param["api_key"] = proxy_param.proxy_api_key + if 
proxy_param.proxy_backend: + openapi_param["model_name"] = proxy_param.proxy_backend + return OpenAPIEmbeddings(**openapi_param) else: from langchain.embeddings import HuggingFaceEmbeddings diff --git a/dbgpt/model/parameter.py b/dbgpt/model/parameter.py index 946d65ecf..3dc68afcc 100644 --- a/dbgpt/model/parameter.py +++ b/dbgpt/model/parameter.py @@ -552,7 +552,7 @@ def build_kwargs(self, **kwargs) -> Dict: _EMBEDDING_PARAMETER_CLASS_TO_NAME_CONFIG = { - ProxyEmbeddingParameters: "proxy_openai,proxy_azure" + ProxyEmbeddingParameters: "proxy_openai,proxy_azure,proxy_http_openapi", } EMBEDDING_NAME_TO_PARAMETER_CLASS_CONFIG = {}