Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(rag): Support rag retriever evaluation #1291

Merged
merged 2 commits into from
Mar 14, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions assets/schema/dbgpt.sql
Original file line number Diff line number Diff line change
Expand Up @@ -278,6 +278,7 @@ CREATE TABLE `dbgpt_serve_flow` (
`flow_category` varchar(64) DEFAULT NULL COMMENT 'Flow category',
`description` varchar(512) DEFAULT NULL COMMENT 'Flow description',
`state` varchar(32) DEFAULT NULL COMMENT 'Flow state',
`error_message` varchar(512) NULL comment 'Error message',
`source` varchar(64) DEFAULT NULL COMMENT 'Flow source',
`source_url` varchar(512) DEFAULT NULL COMMENT 'Flow source url',
`version` varchar(32) DEFAULT NULL COMMENT 'Flow version',
Expand Down
Empty file.
395 changes: 395 additions & 0 deletions assets/schema/upgrade/v0_5_2/v0.5.1.sql

Large diffs are not rendered by default.

2 changes: 2 additions & 0 deletions dbgpt/configs/model_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,6 +167,8 @@ def get_device() -> str:
# https://huggingface.co/BAAI/bge-large-zh
"bge-large-zh": os.path.join(MODEL_PATH, "bge-large-zh"),
"bge-base-zh": os.path.join(MODEL_PATH, "bge-base-zh"),
# https://huggingface.co/BAAI/bge-m3, beg need normalize_embeddings=True
"bge-m3": os.path.join(MODEL_PATH, "bge-m3"),
"gte-large-zh": os.path.join(MODEL_PATH, "gte-large-zh"),
"gte-base-zh": os.path.join(MODEL_PATH, "gte-base-zh"),
"sentence-transforms": os.path.join(MODEL_PATH, "all-MiniLM-L6-v2"),
Expand Down
2 changes: 2 additions & 0 deletions dbgpt/core/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
CachePolicy,
CacheValue,
)
from dbgpt.core.interface.embeddings import Embeddings # noqa: F401
from dbgpt.core.interface.llm import ( # noqa: F401
DefaultMessageConverter,
LLMClient,
Expand Down Expand Up @@ -103,4 +104,5 @@
"DefaultStorageItemAdapter",
"QuerySpec",
"StorageError",
"Embeddings",
]
2 changes: 2 additions & 0 deletions dbgpt/core/awel/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@
CommonLLMHttpResponseBody,
HttpTrigger,
)
from .trigger.iterator_trigger import IteratorTrigger

_request_http_trigger_available = False
try:
Expand Down Expand Up @@ -100,6 +101,7 @@
"TransformStreamAbsOperator",
"Trigger",
"HttpTrigger",
"IteratorTrigger",
"CommonLLMHTTPRequestContext",
"CommonLLMHttpResponseBody",
"CommonLLMHttpRequestBody",
Expand Down
2 changes: 1 addition & 1 deletion dbgpt/core/awel/operators/common_operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -277,7 +277,7 @@ async def _do_run(self, dag_ctx: DAGContext) -> TaskOutput[OUT]:
return task_output


class TriggerOperator(InputOperator, Generic[OUT]):
class TriggerOperator(InputOperator[OUT], Generic[OUT]):
"""Operator node that triggers the DAG to run."""

def __init__(self, **kwargs) -> None:
Expand Down
8 changes: 4 additions & 4 deletions dbgpt/core/awel/runner/local_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,8 +60,8 @@ async def execute_workflow(
streaming_call=streaming_call,
node_name_to_ids=job_manager._node_name_to_ids,
)
if node.dag:
self._running_dag_ctx[node.dag.dag_id] = dag_ctx
# if node.dag:
# self._running_dag_ctx[node.dag.dag_id] = dag_ctx
logger.info(
f"Begin run workflow from end operator, id: {node.node_id}, runner: {self}"
)
Expand All @@ -76,8 +76,8 @@ async def execute_workflow(
if not streaming_call and node.dag:
# streaming call not work for dag end
await node.dag._after_dag_end()
if node.dag:
del self._running_dag_ctx[node.dag.dag_id]
# if node.dag:
# del self._running_dag_ctx[node.dag.dag_id]
return dag_ctx

async def _execute_node(
Expand Down
39 changes: 39 additions & 0 deletions dbgpt/core/awel/task/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,13 @@
from enum import Enum
from typing import (
Any,
AsyncIterable,
AsyncIterator,
Awaitable,
Callable,
Dict,
Generic,
Iterable,
List,
Optional,
TypeVar,
Expand Down Expand Up @@ -421,3 +423,40 @@ async def read(self, task_ctx: TaskContext) -> TaskOutput[T]:
Returns:
TaskOutput[T]: The output object read from current source
"""

@classmethod
def from_data(cls, data: T) -> "InputSource[T]":
"""Create an InputSource from data.

Args:
data (T): The data to create the InputSource from.

Returns:
InputSource[T]: The InputSource created from the data.
"""
from .task_impl import SimpleInputSource

return SimpleInputSource(data, streaming=False)

@classmethod
def from_iterable(
cls, iterable: Union[AsyncIterable[T], Iterable[T]]
) -> "InputSource[T]":
"""Create an InputSource from an iterable.

Args:
iterable (List[T]): The iterable to create the InputSource from.

Returns:
InputSource[T]: The InputSource created from the iterable.
"""
from .task_impl import SimpleInputSource

return SimpleInputSource(iterable, streaming=True)

@classmethod
def from_callable(cls) -> "InputSource[T]":
"""Create an InputSource from a callable."""
from .task_impl import SimpleCallDataInputSource

return SimpleCallDataInputSource()
44 changes: 39 additions & 5 deletions dbgpt/core/awel/task/task_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -261,13 +261,42 @@ def _is_async_iterator(obj):
)


def _is_async_iterable(obj):
return hasattr(obj, "__aiter__") and callable(getattr(obj, "__aiter__", None))


def _is_iterator(obj):
return (
hasattr(obj, "__iter__")
and callable(getattr(obj, "__iter__", None))
and hasattr(obj, "__next__")
and callable(getattr(obj, "__next__", None))
)


def _is_iterable(obj):
return hasattr(obj, "__iter__") and callable(getattr(obj, "__iter__", None))


async def _to_async_iterator(obj) -> AsyncIterator:
if _is_async_iterable(obj):
async for item in obj:
yield item
elif _is_iterable(obj):
for item in obj:
yield item
else:
raise ValueError(f"Can not convert {obj} to AsyncIterator")


class BaseInputSource(InputSource, ABC):
"""The base class of InputSource."""

def __init__(self) -> None:
def __init__(self, streaming: Optional[bool] = None) -> None:
"""Create a BaseInputSource."""
super().__init__()
self._is_read = False
self._streaming_data = streaming

@abstractmethod
def _read_data(self, task_ctx: TaskContext) -> Any:
Expand All @@ -286,10 +315,15 @@ async def read(self, task_ctx: TaskContext) -> TaskOutput:
ValueError: If the input source is a stream and has been read.
"""
data = self._read_data(task_ctx)
if _is_async_iterator(data):
if self._streaming_data is None:
streaming_data = _is_async_iterator(data) or _is_iterator(data)
else:
streaming_data = self._streaming_data
if streaming_data:
if self._is_read:
raise ValueError(f"Input iterator {data} has been read!")
output: TaskOutput = SimpleStreamTaskOutput(data)
it_data = _to_async_iterator(data)
output: TaskOutput = SimpleStreamTaskOutput(it_data)
else:
output = SimpleTaskOutput(data)
self._is_read = True
Expand All @@ -299,13 +333,13 @@ async def read(self, task_ctx: TaskContext) -> TaskOutput:
class SimpleInputSource(BaseInputSource):
"""The default implementation of InputSource."""

def __init__(self, data: Any) -> None:
def __init__(self, data: Any, streaming: Optional[bool] = None) -> None:
"""Create a SimpleInputSource.

Args:
data (Any): The input data.
"""
super().__init__()
super().__init__(streaming=streaming)
self._data = data

def _read_data(self, task_ctx: TaskContext) -> Any:
Expand Down
Empty file.
118 changes: 118 additions & 0 deletions dbgpt/core/awel/tests/trigger/test_iterator_trigger.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,118 @@
from typing import AsyncIterator

import pytest

from dbgpt.core.awel import (
DAG,
InputSource,
MapOperator,
StreamifyAbsOperator,
TransformStreamAbsOperator,
)
from dbgpt.core.awel.trigger.iterator_trigger import IteratorTrigger


class NumberProducerOperator(StreamifyAbsOperator[int, int]):
"""Create a stream of numbers from 0 to `n-1`"""

async def streamify(self, n: int) -> AsyncIterator[int]:
for i in range(n):
yield i


class MyStreamingOperator(TransformStreamAbsOperator[int, int]):
async def transform_stream(self, data: AsyncIterator[int]) -> AsyncIterator[int]:
async for i in data:
yield i * i


async def _check_stream_results(stream_results, expected_len):
assert len(stream_results) == expected_len
for _, result in stream_results:
i = 0
async for num in result:
assert num == i * i
i += 1


@pytest.mark.asyncio
async def test_single_data():
with DAG("test_single_data"):
trigger_task = IteratorTrigger(data=2)
task = MapOperator(lambda x: x * x)
trigger_task >> task
results = await trigger_task.trigger()
assert len(results) == 1
assert results[0][1] == 4

with DAG("test_single_data_stream"):
trigger_task = IteratorTrigger(data=2, streaming_call=True)
number_task = NumberProducerOperator()
task = MyStreamingOperator()
trigger_task >> number_task >> task
stream_results = await trigger_task.trigger()
await _check_stream_results(stream_results, 1)


@pytest.mark.asyncio
async def test_list_data():
with DAG("test_list_data"):
trigger_task = IteratorTrigger(data=[0, 1, 2, 3])
task = MapOperator(lambda x: x * x)
trigger_task >> task
results = await trigger_task.trigger()
assert len(results) == 4
assert results == [(0, 0), (1, 1), (2, 4), (3, 9)]

with DAG("test_list_data_stream"):
trigger_task = IteratorTrigger(data=[0, 1, 2, 3], streaming_call=True)
number_task = NumberProducerOperator()
task = MyStreamingOperator()
trigger_task >> number_task >> task
stream_results = await trigger_task.trigger()
await _check_stream_results(stream_results, 4)


@pytest.mark.asyncio
async def test_async_iterator_data():
async def async_iter():
for i in range(4):
yield i

with DAG("test_async_iterator_data"):
trigger_task = IteratorTrigger(data=async_iter())
task = MapOperator(lambda x: x * x)
trigger_task >> task
results = await trigger_task.trigger()
assert len(results) == 4
assert results == [(0, 0), (1, 1), (2, 4), (3, 9)]

with DAG("test_async_iterator_data_stream"):
trigger_task = IteratorTrigger(data=async_iter(), streaming_call=True)
number_task = NumberProducerOperator()
task = MyStreamingOperator()
trigger_task >> number_task >> task
stream_results = await trigger_task.trigger()
await _check_stream_results(stream_results, 4)


@pytest.mark.asyncio
async def test_input_source_data():
with DAG("test_input_source_data"):
trigger_task = IteratorTrigger(data=InputSource.from_iterable([0, 1, 2, 3]))
task = MapOperator(lambda x: x * x)
trigger_task >> task
results = await trigger_task.trigger()
assert len(results) == 4
assert results == [(0, 0), (1, 1), (2, 4), (3, 9)]

with DAG("test_input_source_data_stream"):
trigger_task = IteratorTrigger(
data=InputSource.from_iterable([0, 1, 2, 3]),
streaming_call=True,
)
number_task = NumberProducerOperator()
task = MyStreamingOperator()
trigger_task >> number_task >> task
stream_results = await trigger_task.trigger()
await _check_stream_results(stream_results, 4)
6 changes: 4 additions & 2 deletions dbgpt/core/awel/trigger/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,16 +2,18 @@
from __future__ import annotations

from abc import ABC, abstractmethod
from typing import Any, Generic

from ..operators.common_operator import TriggerOperator
from ..task.base import OUT


class Trigger(TriggerOperator, ABC):
class Trigger(TriggerOperator[OUT], ABC, Generic[OUT]):
"""Base class for all trigger classes.

Now only support http trigger.
"""

@abstractmethod
async def trigger(self) -> None:
async def trigger(self, **kwargs) -> Any:
"""Trigger the workflow or a specific operation in the workflow."""
4 changes: 2 additions & 2 deletions dbgpt/core/awel/trigger/http_trigger.py
Original file line number Diff line number Diff line change
Expand Up @@ -397,9 +397,9 @@ def __init__(
self._end_node: Optional[BaseOperator] = None
self._register_to_app = register_to_app

async def trigger(self) -> None:
async def trigger(self, **kwargs) -> Any:
"""Trigger the DAG. Not used in HttpTrigger."""
pass
raise NotImplementedError("HttpTrigger does not support trigger directly")

def register_to_app(self) -> bool:
"""Register the trigger to a FastAPI app.
Expand Down
Loading
Loading