Skip to content

Commit

Permalink
feat(rag): Support rag retriever evaluation (#1291)
Browse files Browse the repository at this point in the history
  • Loading branch information
fangyinc authored Mar 14, 2024
1 parent cd2dcc2 commit adaa68e
Show file tree
Hide file tree
Showing 34 changed files with 1,452 additions and 67 deletions.
1 change: 1 addition & 0 deletions assets/schema/dbgpt.sql
Original file line number Diff line number Diff line change
Expand Up @@ -278,6 +278,7 @@ CREATE TABLE `dbgpt_serve_flow` (
`flow_category` varchar(64) DEFAULT NULL COMMENT 'Flow category',
`description` varchar(512) DEFAULT NULL COMMENT 'Flow description',
`state` varchar(32) DEFAULT NULL COMMENT 'Flow state',
`error_message` varchar(512) NULL comment 'Error message',
`source` varchar(64) DEFAULT NULL COMMENT 'Flow source',
`source_url` varchar(512) DEFAULT NULL COMMENT 'Flow source url',
`version` varchar(32) DEFAULT NULL COMMENT 'Flow version',
Expand Down
Empty file.
395 changes: 395 additions & 0 deletions assets/schema/upgrade/v0_5_2/v0.5.1.sql

Large diffs are not rendered by default.

2 changes: 2 additions & 0 deletions dbgpt/configs/model_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,6 +167,8 @@ def get_device() -> str:
# https://huggingface.co/BAAI/bge-large-zh
"bge-large-zh": os.path.join(MODEL_PATH, "bge-large-zh"),
"bge-base-zh": os.path.join(MODEL_PATH, "bge-base-zh"),
# https://huggingface.co/BAAI/bge-m3, beg need normalize_embeddings=True
"bge-m3": os.path.join(MODEL_PATH, "bge-m3"),
"gte-large-zh": os.path.join(MODEL_PATH, "gte-large-zh"),
"gte-base-zh": os.path.join(MODEL_PATH, "gte-base-zh"),
"sentence-transforms": os.path.join(MODEL_PATH, "all-MiniLM-L6-v2"),
Expand Down
2 changes: 2 additions & 0 deletions dbgpt/core/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
CachePolicy,
CacheValue,
)
from dbgpt.core.interface.embeddings import Embeddings # noqa: F401
from dbgpt.core.interface.llm import ( # noqa: F401
DefaultMessageConverter,
LLMClient,
Expand Down Expand Up @@ -103,4 +104,5 @@
"DefaultStorageItemAdapter",
"QuerySpec",
"StorageError",
"Embeddings",
]
2 changes: 2 additions & 0 deletions dbgpt/core/awel/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@
CommonLLMHttpResponseBody,
HttpTrigger,
)
from .trigger.iterator_trigger import IteratorTrigger

_request_http_trigger_available = False
try:
Expand Down Expand Up @@ -100,6 +101,7 @@
"TransformStreamAbsOperator",
"Trigger",
"HttpTrigger",
"IteratorTrigger",
"CommonLLMHTTPRequestContext",
"CommonLLMHttpResponseBody",
"CommonLLMHttpRequestBody",
Expand Down
2 changes: 1 addition & 1 deletion dbgpt/core/awel/operators/common_operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -277,7 +277,7 @@ async def _do_run(self, dag_ctx: DAGContext) -> TaskOutput[OUT]:
return task_output


class TriggerOperator(InputOperator, Generic[OUT]):
class TriggerOperator(InputOperator[OUT], Generic[OUT]):
"""Operator node that triggers the DAG to run."""

def __init__(self, **kwargs) -> None:
Expand Down
8 changes: 4 additions & 4 deletions dbgpt/core/awel/runner/local_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,8 +60,8 @@ async def execute_workflow(
streaming_call=streaming_call,
node_name_to_ids=job_manager._node_name_to_ids,
)
if node.dag:
self._running_dag_ctx[node.dag.dag_id] = dag_ctx
# if node.dag:
# self._running_dag_ctx[node.dag.dag_id] = dag_ctx
logger.info(
f"Begin run workflow from end operator, id: {node.node_id}, runner: {self}"
)
Expand All @@ -76,8 +76,8 @@ async def execute_workflow(
if not streaming_call and node.dag:
# streaming call not work for dag end
await node.dag._after_dag_end()
if node.dag:
del self._running_dag_ctx[node.dag.dag_id]
# if node.dag:
# del self._running_dag_ctx[node.dag.dag_id]
return dag_ctx

async def _execute_node(
Expand Down
39 changes: 39 additions & 0 deletions dbgpt/core/awel/task/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,13 @@
from enum import Enum
from typing import (
Any,
AsyncIterable,
AsyncIterator,
Awaitable,
Callable,
Dict,
Generic,
Iterable,
List,
Optional,
TypeVar,
Expand Down Expand Up @@ -421,3 +423,40 @@ async def read(self, task_ctx: TaskContext) -> TaskOutput[T]:
Returns:
TaskOutput[T]: The output object read from current source
"""

@classmethod
def from_data(cls, data: T) -> "InputSource[T]":
"""Create an InputSource from data.
Args:
data (T): The data to create the InputSource from.
Returns:
InputSource[T]: The InputSource created from the data.
"""
from .task_impl import SimpleInputSource

return SimpleInputSource(data, streaming=False)

@classmethod
def from_iterable(
cls, iterable: Union[AsyncIterable[T], Iterable[T]]
) -> "InputSource[T]":
"""Create an InputSource from an iterable.
Args:
iterable (List[T]): The iterable to create the InputSource from.
Returns:
InputSource[T]: The InputSource created from the iterable.
"""
from .task_impl import SimpleInputSource

return SimpleInputSource(iterable, streaming=True)

@classmethod
def from_callable(cls) -> "InputSource[T]":
"""Create an InputSource from a callable."""
from .task_impl import SimpleCallDataInputSource

return SimpleCallDataInputSource()
44 changes: 39 additions & 5 deletions dbgpt/core/awel/task/task_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -261,13 +261,42 @@ def _is_async_iterator(obj):
)


def _is_async_iterable(obj):
return hasattr(obj, "__aiter__") and callable(getattr(obj, "__aiter__", None))


def _is_iterator(obj):
return (
hasattr(obj, "__iter__")
and callable(getattr(obj, "__iter__", None))
and hasattr(obj, "__next__")
and callable(getattr(obj, "__next__", None))
)


def _is_iterable(obj):
return hasattr(obj, "__iter__") and callable(getattr(obj, "__iter__", None))


async def _to_async_iterator(obj) -> AsyncIterator:
if _is_async_iterable(obj):
async for item in obj:
yield item
elif _is_iterable(obj):
for item in obj:
yield item
else:
raise ValueError(f"Can not convert {obj} to AsyncIterator")


class BaseInputSource(InputSource, ABC):
"""The base class of InputSource."""

def __init__(self) -> None:
def __init__(self, streaming: Optional[bool] = None) -> None:
"""Create a BaseInputSource."""
super().__init__()
self._is_read = False
self._streaming_data = streaming

@abstractmethod
def _read_data(self, task_ctx: TaskContext) -> Any:
Expand All @@ -286,10 +315,15 @@ async def read(self, task_ctx: TaskContext) -> TaskOutput:
ValueError: If the input source is a stream and has been read.
"""
data = self._read_data(task_ctx)
if _is_async_iterator(data):
if self._streaming_data is None:
streaming_data = _is_async_iterator(data) or _is_iterator(data)
else:
streaming_data = self._streaming_data
if streaming_data:
if self._is_read:
raise ValueError(f"Input iterator {data} has been read!")
output: TaskOutput = SimpleStreamTaskOutput(data)
it_data = _to_async_iterator(data)
output: TaskOutput = SimpleStreamTaskOutput(it_data)
else:
output = SimpleTaskOutput(data)
self._is_read = True
Expand All @@ -299,13 +333,13 @@ async def read(self, task_ctx: TaskContext) -> TaskOutput:
class SimpleInputSource(BaseInputSource):
"""The default implementation of InputSource."""

def __init__(self, data: Any) -> None:
def __init__(self, data: Any, streaming: Optional[bool] = None) -> None:
"""Create a SimpleInputSource.
Args:
data (Any): The input data.
"""
super().__init__()
super().__init__(streaming=streaming)
self._data = data

def _read_data(self, task_ctx: TaskContext) -> Any:
Expand Down
Empty file.
118 changes: 118 additions & 0 deletions dbgpt/core/awel/tests/trigger/test_iterator_trigger.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,118 @@
from typing import AsyncIterator

import pytest

from dbgpt.core.awel import (
DAG,
InputSource,
MapOperator,
StreamifyAbsOperator,
TransformStreamAbsOperator,
)
from dbgpt.core.awel.trigger.iterator_trigger import IteratorTrigger


class NumberProducerOperator(StreamifyAbsOperator[int, int]):
"""Create a stream of numbers from 0 to `n-1`"""

async def streamify(self, n: int) -> AsyncIterator[int]:
for i in range(n):
yield i


class MyStreamingOperator(TransformStreamAbsOperator[int, int]):
async def transform_stream(self, data: AsyncIterator[int]) -> AsyncIterator[int]:
async for i in data:
yield i * i


async def _check_stream_results(stream_results, expected_len):
assert len(stream_results) == expected_len
for _, result in stream_results:
i = 0
async for num in result:
assert num == i * i
i += 1


@pytest.mark.asyncio
async def test_single_data():
with DAG("test_single_data"):
trigger_task = IteratorTrigger(data=2)
task = MapOperator(lambda x: x * x)
trigger_task >> task
results = await trigger_task.trigger()
assert len(results) == 1
assert results[0][1] == 4

with DAG("test_single_data_stream"):
trigger_task = IteratorTrigger(data=2, streaming_call=True)
number_task = NumberProducerOperator()
task = MyStreamingOperator()
trigger_task >> number_task >> task
stream_results = await trigger_task.trigger()
await _check_stream_results(stream_results, 1)


@pytest.mark.asyncio
async def test_list_data():
with DAG("test_list_data"):
trigger_task = IteratorTrigger(data=[0, 1, 2, 3])
task = MapOperator(lambda x: x * x)
trigger_task >> task
results = await trigger_task.trigger()
assert len(results) == 4
assert results == [(0, 0), (1, 1), (2, 4), (3, 9)]

with DAG("test_list_data_stream"):
trigger_task = IteratorTrigger(data=[0, 1, 2, 3], streaming_call=True)
number_task = NumberProducerOperator()
task = MyStreamingOperator()
trigger_task >> number_task >> task
stream_results = await trigger_task.trigger()
await _check_stream_results(stream_results, 4)


@pytest.mark.asyncio
async def test_async_iterator_data():
async def async_iter():
for i in range(4):
yield i

with DAG("test_async_iterator_data"):
trigger_task = IteratorTrigger(data=async_iter())
task = MapOperator(lambda x: x * x)
trigger_task >> task
results = await trigger_task.trigger()
assert len(results) == 4
assert results == [(0, 0), (1, 1), (2, 4), (3, 9)]

with DAG("test_async_iterator_data_stream"):
trigger_task = IteratorTrigger(data=async_iter(), streaming_call=True)
number_task = NumberProducerOperator()
task = MyStreamingOperator()
trigger_task >> number_task >> task
stream_results = await trigger_task.trigger()
await _check_stream_results(stream_results, 4)


@pytest.mark.asyncio
async def test_input_source_data():
with DAG("test_input_source_data"):
trigger_task = IteratorTrigger(data=InputSource.from_iterable([0, 1, 2, 3]))
task = MapOperator(lambda x: x * x)
trigger_task >> task
results = await trigger_task.trigger()
assert len(results) == 4
assert results == [(0, 0), (1, 1), (2, 4), (3, 9)]

with DAG("test_input_source_data_stream"):
trigger_task = IteratorTrigger(
data=InputSource.from_iterable([0, 1, 2, 3]),
streaming_call=True,
)
number_task = NumberProducerOperator()
task = MyStreamingOperator()
trigger_task >> number_task >> task
stream_results = await trigger_task.trigger()
await _check_stream_results(stream_results, 4)
6 changes: 4 additions & 2 deletions dbgpt/core/awel/trigger/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,16 +2,18 @@
from __future__ import annotations

from abc import ABC, abstractmethod
from typing import Any, Generic

from ..operators.common_operator import TriggerOperator
from ..task.base import OUT


class Trigger(TriggerOperator, ABC):
class Trigger(TriggerOperator[OUT], ABC, Generic[OUT]):
"""Base class for all trigger classes.
Now only support http trigger.
"""

@abstractmethod
async def trigger(self) -> None:
async def trigger(self, **kwargs) -> Any:
"""Trigger the workflow or a specific operation in the workflow."""
4 changes: 2 additions & 2 deletions dbgpt/core/awel/trigger/http_trigger.py
Original file line number Diff line number Diff line change
Expand Up @@ -397,9 +397,9 @@ def __init__(
self._end_node: Optional[BaseOperator] = None
self._register_to_app = register_to_app

async def trigger(self) -> None:
async def trigger(self, **kwargs) -> Any:
"""Trigger the DAG. Not used in HttpTrigger."""
pass
raise NotImplementedError("HttpTrigger does not support trigger directly")

def register_to_app(self) -> bool:
"""Register the trigger to a FastAPI app.
Expand Down
Loading

0 comments on commit adaa68e

Please sign in to comment.