Skip to content

Commit

Permalink
fix: Fixed multi-turn dialogue bug (#1259)
Browse files: browse the repository at this point in the history
  • Loading branch information
fangyinc authored Mar 6, 2024
1 parent 74ec8e5 commit 872b574
Show file tree
Hide file tree
Showing 7 changed files with 199 additions and 35 deletions.
7 changes: 7 additions & 0 deletions .env.template
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,13 @@ KNOWLEDGE_SEARCH_REWRITE=False
# proxy_openai_proxy_api_key={your-openai-sk}
# proxy_openai_proxy_backend=text-embedding-ada-002

## Common HTTP embedding model
# EMBEDDING_MODEL=proxy_http_openapi
# proxy_http_openapi_proxy_server_url=http://localhost:8100/api/v1/embeddings
# proxy_http_openapi_proxy_api_key=1dce29a6d66b4e2dbfec67044edbb924
# proxy_http_openapi_proxy_backend=text2vec



#*******************************************************************#
#** DB-GPT METADATA DATABASE SETTINGS **#
Expand Down
2 changes: 2 additions & 0 deletions dbgpt/configs/model_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,6 +172,8 @@ def get_device() -> str:
"sentence-transforms": os.path.join(MODEL_PATH, "all-MiniLM-L6-v2"),
"proxy_openai": "proxy_openai",
"proxy_azure": "proxy_azure",
# Common HTTP embedding model
"proxy_http_openapi": "proxy_http_openapi",
}


Expand Down
46 changes: 16 additions & 30 deletions dbgpt/core/interface/operators/message_operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -370,9 +370,13 @@ def __init__(
):
"""Create a new BufferedConversationMapperOperator."""
# Validate the input parameters
if keep_start_rounds is not None and keep_start_rounds < 0:
if keep_start_rounds is None:
keep_start_rounds = 0
if keep_end_rounds is None:
keep_end_rounds = 0
if keep_start_rounds < 0:
raise ValueError("keep_start_rounds must be non-negative")
if keep_end_rounds is not None and keep_end_rounds < 0:
if keep_end_rounds < 0:
raise ValueError("keep_end_rounds must be non-negative")

self._keep_start_rounds = keep_start_rounds
Expand Down Expand Up @@ -420,7 +424,7 @@ def _filter_round_messages(
... ],
... ]
# Test keeping only the first 2 rounds
>>> # Test keeping only the first 2 rounds
>>> operator = BufferedConversationMapperOperator(keep_start_rounds=2)
>>> assert operator._filter_round_messages(messages) == [
... [
Expand All @@ -433,7 +437,7 @@ def _filter_round_messages(
... ],
... ]
# Test keeping only the last 2 rounds
>>> # Test keeping only the last 2 rounds
>>> operator = BufferedConversationMapperOperator(keep_end_rounds=2)
>>> assert operator._filter_round_messages(messages) == [
... [
Expand All @@ -446,7 +450,7 @@ def _filter_round_messages(
... ],
... ]
# Test keeping the first 2 and last 1 rounds
>>> # Test keeping the first 2 and last 1 rounds
>>> operator = BufferedConversationMapperOperator(
... keep_start_rounds=2, keep_end_rounds=1
... )
Expand All @@ -465,24 +469,11 @@ def _filter_round_messages(
... ],
... ]
# Test without specifying start or end rounds (keep all rounds)
>>> # Test without specifying start or end rounds (keep 0 rounds)
>>> operator = BufferedConversationMapperOperator()
>>> assert operator._filter_round_messages(messages) == [
... [
... HumanMessage(content="Hi", round_index=1),
... AIMessage(content="Hello!", round_index=1),
... ],
... [
... HumanMessage(content="How are you?", round_index=2),
... AIMessage(content="I'm good, thanks!", round_index=2),
... ],
... [
... HumanMessage(content="What's new today?", round_index=3),
... AIMessage(content="Lots of things!", round_index=3),
... ],
... ]
>>> assert operator._filter_round_messages(messages) == []
# Test end rounds is zero
>>> # Test end rounds is zero
>>> operator = BufferedConversationMapperOperator(
... keep_start_rounds=1, keep_end_rounds=0
... )
Expand All @@ -503,12 +494,7 @@ def _filter_round_messages(
"""
total_rounds = len(messages_by_round)
if (
self._keep_start_rounds is not None
and self._keep_end_rounds is not None
and self._keep_start_rounds > 0
and self._keep_end_rounds > 0
):
if self._keep_start_rounds > 0 and self._keep_end_rounds > 0:
if self._keep_start_rounds + self._keep_end_rounds > total_rounds:
# Avoid overlapping when the sum of start and end rounds exceeds total
# rounds
Expand All @@ -517,12 +503,12 @@ def _filter_round_messages(
messages_by_round[: self._keep_start_rounds]
+ messages_by_round[-self._keep_end_rounds :]
)
elif self._keep_start_rounds is not None:
elif self._keep_start_rounds:
return messages_by_round[: self._keep_start_rounds]
elif self._keep_end_rounds is not None:
elif self._keep_end_rounds:
return messages_by_round[-self._keep_end_rounds :]
else:
return messages_by_round
return []


EvictionPolicyType = Callable[[List[List[BaseMessage]]], List[List[BaseMessage]]]
Expand Down
Empty file.
155 changes: 155 additions & 0 deletions dbgpt/core/interface/operators/tests/test_message_operator.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,155 @@
from typing import List

import pytest

from dbgpt.core.interface.message import AIMessage, BaseMessage, HumanMessage
from dbgpt.core.operators import BufferedConversationMapperOperator


@pytest.fixture
def messages() -> List[BaseMessage]:
    """Provide a three-round conversation as a flat list of messages.

    Each round contributes one ``HumanMessage`` followed by one ``AIMessage``,
    and rounds are numbered from 1.
    """
    exchanges = [
        ("Hi", "Hello!"),
        ("How are you?", "I'm good, thanks!"),
        ("What's new today?", "Lots of things!"),
    ]
    history: List[BaseMessage] = []
    for round_no, (question, answer) in enumerate(exchanges, start=1):
        history.append(HumanMessage(content=question, round_index=round_no))
        history.append(AIMessage(content=answer, round_index=round_no))
    return history


@pytest.mark.asyncio
async def test_buffered_conversation_keep_start_rounds(messages: List[BaseMessage]):
    """Check that ``keep_start_rounds`` alone keeps only the leading rounds.

    Covers a normal value, zero, a value larger than the history, and a
    negative value, with ``keep_end_rounds`` left as ``None`` throughout.
    """
    # Keep the first 2 of the 3 rounds.
    operator = BufferedConversationMapperOperator(
        keep_start_rounds=2,
        keep_end_rounds=None,
    )
    assert await operator.map_messages(messages) == [
        HumanMessage(content="Hi", round_index=1),
        AIMessage(content="Hello!", round_index=1),
        HumanMessage(content="How are you?", round_index=2),
        AIMessage(content="I'm good, thanks!", round_index=2),
    ]

    # Keeping 0 start rounds (and no end rounds) yields an empty history.
    operator = BufferedConversationMapperOperator(
        keep_start_rounds=0,
        keep_end_rounds=None,
    )
    assert await operator.map_messages(messages) == []

    # Asking for more rounds than exist returns the whole history.
    operator = BufferedConversationMapperOperator(
        keep_start_rounds=100,
        keep_end_rounds=None,
    )
    assert await operator.map_messages(messages) == messages

    # A negative count must be rejected. Only the constructor call sits inside
    # the ``raises`` block so that a ValueError raised by some later call
    # cannot make this test pass for the wrong reason.
    with pytest.raises(ValueError, match="keep_start_rounds must be non-negative"):
        BufferedConversationMapperOperator(
            keep_start_rounds=-1,
            keep_end_rounds=None,
        )


@pytest.mark.asyncio
async def test_buffered_conversation_keep_end_rounds(messages: List[BaseMessage]):
    """Check that ``keep_end_rounds`` alone keeps only the trailing rounds.

    Covers a normal value, zero, a value larger than the history, and a
    negative value, with ``keep_start_rounds`` left as ``None`` throughout.
    """
    # Keep the last 2 of the 3 rounds.
    operator = BufferedConversationMapperOperator(
        keep_start_rounds=None,
        keep_end_rounds=2,
    )
    assert await operator.map_messages(messages) == [
        HumanMessage(content="How are you?", round_index=2),
        AIMessage(content="I'm good, thanks!", round_index=2),
        HumanMessage(content="What's new today?", round_index=3),
        AIMessage(content="Lots of things!", round_index=3),
    ]

    # Keeping 0 end rounds (and no start rounds) yields an empty history.
    # ``keep_start_rounds`` stays ``None`` like the other cases here, so this
    # test exercises ``keep_end_rounds`` in isolation.
    operator = BufferedConversationMapperOperator(
        keep_start_rounds=None,
        keep_end_rounds=0,
    )
    assert await operator.map_messages(messages) == []

    # Asking for more rounds than exist returns the whole history.
    operator = BufferedConversationMapperOperator(
        keep_start_rounds=None,
        keep_end_rounds=100,
    )
    assert await operator.map_messages(messages) == messages

    # A negative count must be rejected. Only the constructor call sits inside
    # the ``raises`` block so that a ValueError raised by some later call
    # cannot make this test pass for the wrong reason.
    with pytest.raises(ValueError, match="keep_end_rounds must be non-negative"):
        BufferedConversationMapperOperator(
            keep_start_rounds=None,
            keep_end_rounds=-1,
        )


@pytest.mark.asyncio
async def test_buffered_conversation_keep_start_end_rounds(messages: List[BaseMessage]):
    """Check ``keep_start_rounds`` and ``keep_end_rounds`` used together.

    Covers disjoint head and tail selection, zero on either or both sides,
    requests that overlap or exceed the history, and negative values.
    """
    # First round plus last round; the middle round is dropped.
    operator = BufferedConversationMapperOperator(
        keep_start_rounds=1,
        keep_end_rounds=1,
    )
    assert await operator.map_messages(messages) == [
        HumanMessage(content="Hi", round_index=1),
        AIMessage(content="Hello!", round_index=1),
        HumanMessage(content="What's new today?", round_index=3),
        AIMessage(content="Lots of things!", round_index=3),
    ]

    # Zero on both sides keeps nothing.
    operator = BufferedConversationMapperOperator(
        keep_start_rounds=0,
        keep_end_rounds=0,
    )
    assert await operator.map_messages(messages) == []

    # Zero start rounds with one end round keeps only the last round.
    operator = BufferedConversationMapperOperator(
        keep_start_rounds=0,
        keep_end_rounds=1,
    )
    assert await operator.map_messages(messages) == [
        HumanMessage(content="What's new today?", round_index=3),
        AIMessage(content="Lots of things!", round_index=3),
    ]

    # Two start rounds with zero end rounds keeps only the first two rounds.
    operator = BufferedConversationMapperOperator(
        keep_start_rounds=2,
        keep_end_rounds=0,
    )
    assert await operator.map_messages(messages) == [
        HumanMessage(content="Hi", round_index=1),
        AIMessage(content="Hello!", round_index=1),
        HumanMessage(content="How are you?", round_index=2),
        AIMessage(content="I'm good, thanks!", round_index=2),
    ]

    # Both counts far exceed the history: everything is returned once.
    operator = BufferedConversationMapperOperator(
        keep_start_rounds=100,
        keep_end_rounds=100,
    )
    assert await operator.map_messages(messages) == messages

    # Head and tail overlap (2 + 2 > 3 rounds): no round may be duplicated.
    operator = BufferedConversationMapperOperator(
        keep_start_rounds=2,
        keep_end_rounds=2,
    )
    assert await operator.map_messages(messages) == messages

    # Negative counts must be rejected. Only the constructor call sits inside
    # the ``raises`` block so that a ValueError raised by some later call
    # cannot make this test pass for the wrong reason.
    with pytest.raises(ValueError):
        BufferedConversationMapperOperator(
            keep_start_rounds=-1,
            keep_end_rounds=-1,
        )
22 changes: 18 additions & 4 deletions dbgpt/model/cluster/embedding/loader.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,16 @@
from __future__ import annotations

from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, Union, cast

from dbgpt.model.parameter import BaseEmbeddingModelParameters
from dbgpt.model.parameter import BaseEmbeddingModelParameters, ProxyEmbeddingParameters
from dbgpt.util.parameter_utils import _get_dict_from_obj
from dbgpt.util.system_utils import get_system_info
from dbgpt.util.tracer import SpanType, SpanTypeRunName, root_tracer

if TYPE_CHECKING:
from langchain.embeddings.base import Embeddings
from langchain.embeddings.base import Embeddings as LangChainEmbeddings

from dbgpt.rag.embedding import Embeddings


class EmbeddingLoader:
Expand All @@ -17,7 +19,7 @@ def __init__(self) -> None:

def load(
self, model_name: str, param: BaseEmbeddingModelParameters
) -> "Embeddings":
) -> "Union[LangChainEmbeddings, Embeddings]":
metadata = {
"model_name": model_name,
"run_service": SpanTypeRunName.EMBEDDING_MODEL.value,
Expand All @@ -32,6 +34,18 @@ def load(
from langchain.embeddings import OpenAIEmbeddings

return OpenAIEmbeddings(**param.build_kwargs())
elif model_name in ["proxy_http_openapi"]:
from dbgpt.rag.embedding import OpenAPIEmbeddings

proxy_param = cast(ProxyEmbeddingParameters, param)
openapi_param = {}
if proxy_param.proxy_server_url:
openapi_param["api_url"] = proxy_param.proxy_server_url
if proxy_param.proxy_api_key:
openapi_param["api_key"] = proxy_param.proxy_api_key
if proxy_param.proxy_backend:
openapi_param["model_name"] = proxy_param.proxy_backend
return OpenAPIEmbeddings(**openapi_param)
else:
from langchain.embeddings import HuggingFaceEmbeddings

Expand Down
2 changes: 1 addition & 1 deletion dbgpt/model/parameter.py
Original file line number Diff line number Diff line change
Expand Up @@ -552,7 +552,7 @@ def build_kwargs(self, **kwargs) -> Dict:


_EMBEDDING_PARAMETER_CLASS_TO_NAME_CONFIG = {
ProxyEmbeddingParameters: "proxy_openai,proxy_azure"
ProxyEmbeddingParameters: "proxy_openai,proxy_azure,proxy_http_openapi",
}

EMBEDDING_NAME_TO_PARAMETER_CLASS_CONFIG = {}
Expand Down

0 comments on commit 872b574

Please sign in to comment.