-
-
Notifications
You must be signed in to change notification settings - Fork 270
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Loading status checks…
Refactor API server with streaming and passage return (#810)
* add deploy package instead of deploy.py and make OpenAPI spec for new API server * moved from deploy.py * make /v1/run api * add data format extra columns * add stream async generator functions at the generators * refactor api server to use extract_retrieve_passage for simplicity * checkpoint before implementing quart * streaming working!!! * edit documentation and fix error at GradioRunner * resolve test errors * refactor API endpoint docs --------- Co-authored-by: jeffrey <[email protected]>
Showing
22 changed files
with
897 additions
and
126 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,9 @@ | ||
from .base import ( | ||
extract_node_line_names, | ||
extract_node_strategy, | ||
summary_df_to_yaml, | ||
extract_best_config, | ||
Runner, | ||
) | ||
from .api import ApiRunner | ||
from .gradio import GradioRunner |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,227 @@ | ||
import logging | ||
import os | ||
import pathlib | ||
import uuid | ||
from typing import Dict, Optional, List, Union | ||
|
||
import pandas as pd | ||
from quart import Quart, request, jsonify | ||
from quart.helpers import stream_with_context | ||
from pydantic import BaseModel, ValidationError | ||
|
||
from autorag.deploy.base import BaseRunner | ||
from autorag.nodes.generator.base import BaseGenerator | ||
from autorag.utils import fetch_contents | ||
|
||
logger = logging.getLogger("AutoRAG") | ||
|
||
deploy_dir = pathlib.Path(__file__).parent | ||
root_dir = pathlib.Path(__file__).parent.parent | ||
|
||
VERSION_PATH = os.path.join(root_dir, "VERSION") | ||
|
||
|
||
class QueryRequest(BaseModel):
    """Request body for the /v1/run and /v1/stream endpoints."""

    query: str  # the user query to run through the pipeline
    # Column of the final pipeline result DataFrame to return as the answer;
    # "generated_texts" is the output column of the generator node.
    result_column: Optional[str] = "generated_texts"
|
||
|
||
class RetrievedPassage(BaseModel):
    """A single retrieved corpus passage returned alongside the generated answer."""

    content: str  # passage text fetched from the corpus
    doc_id: str  # corpus document id of the passage
    filepath: Optional[str] = None  # source file path, if the corpus has a "path" column
    file_page: Optional[int] = None  # page number from the passage metadata, if present
    start_idx: Optional[int] = None  # start index from the corpus "start_end_idx" column, if present
    end_idx: Optional[int] = None  # end index from the corpus "start_end_idx" column, if present
|
||
|
||
class RunResponse(BaseModel):
    """Response body for /v1/run; also the shape of each streamed /v1/stream chunk."""

    result: Union[str, List[str]]  # generated answer text (or list of texts)
    retrieved_passage: List[RetrievedPassage]  # passages retrieved for the query
|
||
|
||
class VersionResponse(BaseModel):
    """Response body for the /version endpoint."""

    version: str  # AutoRAG version string read from the VERSION file
|
||
|
||
# Placeholder passage attached to streamed delta chunks, which carry only newly
# generated text — the real passages are sent once in the first stream chunk.
empty_retrieved_passage = RetrievedPassage(
    content="", doc_id="", filepath=None, file_page=None, start_idx=None, end_idx=None
)
|
||
|
||
class ApiRunner(BaseRunner):
    """Serve an AutoRAG pipeline as an async (Quart) API server.

    Exposes three endpoints:
      - ``POST /v1/run``: run the full pipeline and return the final answer
        together with the retrieved passages.
      - ``POST /v1/stream``: run the pipeline, send the retrieved passages
        first, then stream the generated answer incrementally.
      - ``GET /version``: return the installed AutoRAG version string.
    """

    def __init__(self, config: Dict, project_dir: Optional[str] = None):
        """
        :param config: The extracted pipeline configuration (see BaseRunner).
        :param project_dir: The AutoRAG project directory. It must contain
            ``data/corpus.parquet``, which is loaded so retrieved passage
            contents can be resolved from doc ids.
        """
        super().__init__(config, project_dir)
        self.app = Quart(__name__)

        data_dir = os.path.join(project_dir, "data")
        self.corpus_df = pd.read_parquet(
            os.path.join(data_dir, "corpus.parquet"), engine="pyarrow"
        )
        self.__add_api_route()

    @staticmethod
    def __make_pseudo_qa(query: str) -> pd.DataFrame:
        # Pseudo qa data for execution: lets a pipeline that expects qa-format
        # input start from a single raw query string.
        return pd.DataFrame(
            {
                "qid": str(uuid.uuid4()),
                "query": [query],
                "retrieval_gt": [[]],
                "generation_gt": [""],
            }
        )

    @staticmethod
    def __merge_result(
        previous_result: pd.DataFrame, new_result: pd.DataFrame
    ) -> pd.DataFrame:
        # Drop columns the new module re-emitted so the newest values win,
        # then concatenate the remaining previous columns with the new ones.
        duplicated_columns = previous_result.columns.intersection(new_result.columns)
        drop_previous_result = previous_result.drop(columns=duplicated_columns)
        return pd.concat([drop_previous_result, new_result], axis=1)

    def __add_api_route(self):
        @self.app.route("/v1/run", methods=["POST"])
        async def run_query():
            try:
                data = await request.get_json()
                data = QueryRequest(**data)
            except ValidationError as e:
                return jsonify(e.errors()), 400

            previous_result = self.__make_pseudo_qa(data.query)
            for module_instance, module_param in zip(
                self.module_instances, self.module_params
            ):
                new_result = module_instance.pure(
                    previous_result=previous_result, **module_param
                )
                previous_result = self.__merge_result(previous_result, new_result)

            generated_text = previous_result[data.result_column].tolist()[0]
            retrieved_passage = self.extract_retrieve_passage(previous_result)

            response = RunResponse(
                result=generated_text, retrieved_passage=retrieved_passage
            )
            return jsonify(response.model_dump()), 200

        @self.app.route("/v1/stream", methods=["POST"])
        async def stream_query():
            try:
                data = await request.get_json()
                data = QueryRequest(**data)
            except ValidationError as e:
                return jsonify(e.errors()), 400

            @stream_with_context
            async def generate():
                previous_result = self.__make_pseudo_qa(data.query)

                for module_instance, module_param in zip(
                    self.module_instances, self.module_params
                ):
                    if not isinstance(module_instance, BaseGenerator):
                        new_result = module_instance.pure(
                            previous_result=previous_result, **module_param
                        )
                        previous_result = self.__merge_result(
                            previous_result, new_result
                        )
                    else:
                        # First chunk: the retrieved passages with an empty result.
                        retrieved_passages = self.extract_retrieve_passage(
                            previous_result
                        )
                        response = RunResponse(
                            result="", retrieved_passage=retrieved_passages
                        )
                        yield response.model_dump_json().encode("utf-8")
                        # Then stream generated deltas, one chunk per delta.
                        assert len(previous_result) == 1
                        prompt: str = previous_result["prompts"].tolist()[0]
                        async for delta in module_instance.stream(
                            prompt=prompt, **module_param
                        ):
                            response = RunResponse(
                                result=delta,
                                retrieved_passage=[empty_retrieved_passage],
                            )
                            yield response.model_dump_json().encode("utf-8")

            return generate(), 200, {"X-Something": "value"}

        @self.app.route("/version", methods=["GET"])
        def get_version():
            with open(VERSION_PATH, "r") as f:
                version = f.read().strip()
            response = VersionResponse(version=version)
            return jsonify(response.model_dump()), 200

    def run_api_server(self, host: str = "0.0.0.0", port: int = 8000, **kwargs):
        """
        Run the pipeline as an api server.
        You can send a POST request to `http://host:port/v1/run` with a json body like below:

        .. Code:: json
            {
                "query": "your query",
                "result_column": "generated_texts"
            }

        And it returns a json response like below:

        .. Code:: json
            {
                "result": "generated answer",
                "retrieved_passage": [{"content": "...", "doc_id": "...", ...}]
            }

        :param host: The host of the api server.
        :param port: The port of the api server.
        :param kwargs: Other arguments for Quart app.run.
        """
        logger.info(f"Run api server at {host}:{port}")
        self.app.run(host=host, port=port, **kwargs)

    def extract_retrieve_passage(self, df: pd.DataFrame) -> List[RetrievedPassage]:
        """Build RetrievedPassage models for the first row of the result DataFrame.

        Looks up contents (and optionally paths, page metadata, and
        start/end indices, when those corpus columns exist) for the row's
        `retrieved_ids`.

        :param df: A pipeline result DataFrame containing a `retrieved_ids` column.
        :return: One RetrievedPassage per retrieved doc id of the first row.
        """
        retrieved_ids: List[str] = df["retrieved_ids"].tolist()[0]
        contents = fetch_contents(self.corpus_df, [retrieved_ids])[0]
        if "path" in self.corpus_df.columns:
            paths = fetch_contents(self.corpus_df, [retrieved_ids], column_name="path")[
                0
            ]
        else:
            paths = [None] * len(retrieved_ids)
        metadatas = fetch_contents(
            self.corpus_df, [retrieved_ids], column_name="metadata"
        )[0]
        if "start_end_idx" in self.corpus_df.columns:
            start_end_indices = fetch_contents(
                self.corpus_df, [retrieved_ids], column_name="start_end_idx"
            )[0]
        else:
            start_end_indices = [None] * len(retrieved_ids)
        return list(
            map(
                lambda content, doc_id, path, metadata, start_end_idx: RetrievedPassage(
                    content=content,
                    doc_id=doc_id,
                    filepath=path,
                    # metadata may be missing/None for some corpus rows
                    file_page=metadata.get("page", None) if metadata else None,
                    start_idx=start_end_idx[0] if start_end_idx else None,
                    end_idx=start_end_idx[1] if start_end_idx else None,
                ),
                contents,
                retrieved_ids,
                paths,
                metadatas,
                start_end_indices,
            )
        )
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,74 @@ | ||
import logging | ||
import uuid | ||
|
||
import pandas as pd | ||
|
||
from autorag.deploy.base import BaseRunner | ||
|
||
import gradio as gr | ||
|
||
|
||
logger = logging.getLogger("AutoRAG") | ||
|
||
|
||
class GradioRunner(BaseRunner):
    """Serve an AutoRAG pipeline behind a Gradio chat web interface."""

    def run_web(
        self,
        server_name: str = "0.0.0.0",
        server_port: int = 7680,
        share: bool = False,
        **kwargs,
    ):
        """
        Run a web interface to interact with the pipeline.
        You can access the web interface at `http://server_name:server_port` in your browser.

        :param server_name: The host of the web. Default is 0.0.0.0.
        :param server_port: The port of the web. Default is 7680.
        :param share: Whether to create a publicly shareable link. Default is False.
        :param kwargs: Other arguments for gr.ChatInterface.launch.
        """

        logger.info(f"Run web interface at http://{server_name}:{server_port}")

        def get_response(message, _):
            # Gradio's chat callback receives (message, history); the history
            # is ignored because each query runs the pipeline from scratch.
            return self.run(message)

        gr.ChatInterface(
            get_response, title="📚 AutoRAG", retry_btn=None, undo_btn=None
        ).launch(
            server_name=server_name, server_port=server_port, share=share, **kwargs
        )

    def run(self, query: str, result_column: str = "generated_texts"):
        """
        Run the pipeline with query.
        The loaded pipeline must start with a single query,
        so the first module of the pipeline must be `query_expansion` or `retrieval` module.

        :param query: The query of the user.
        :param result_column: The result column name for the answer.
            Default is `generated_texts`, which is the output of the `generation` module.
        :return: The result of the pipeline.
        """
        previous_result = pd.DataFrame(
            {
                "qid": str(uuid.uuid4()),
                "query": [query],
                "retrieval_gt": [[]],
                "generation_gt": [""],
            }
        )  # pseudo qa data for execution
        for module_instance, module_param in zip(
            self.module_instances, self.module_params
        ):
            new_result = module_instance.pure(
                previous_result=previous_result, **module_param
            )
            # Drop columns the new module re-emitted so the newest values win.
            duplicated_columns = previous_result.columns.intersection(
                new_result.columns
            )
            drop_previous_result = previous_result.drop(columns=duplicated_columns)
            previous_result = pd.concat([drop_previous_result, new_result], axis=1)

        return previous_result[result_column].tolist()[0]
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,135 @@ | ||
# OpenAPI 3.0 specification for the AutoRAG deploy API server
# (POST /v1/run, POST /v1/stream, GET /version).
openapi: 3.0.0
info:
  title: Example API
  version: 1.0.0
paths:
  /v1/run:
    post:
      summary: Run a query and get generated text with retrieved passages
      requestBody:
        required: true
        content:
          application/json:
            schema:
              type: object
              properties:
                query:
                  type: string
                  description: The query string
                result_column:
                  type: string
                  description: The result column name
                  default: generated_texts
              required:
                - query
      responses:
        '200':
          description: Successful response
          content:
            application/json:
              schema:
                type: object
                properties:
                  result:
                    oneOf:
                      - type: string
                      - type: array
                        items:
                          type: string
                    description: The result text or list of texts
                  retrieved_passage:
                    type: array
                    items:
                      type: object
                      properties:
                        content:
                          type: string
                        doc_id:
                          type: string
                        filepath:
                          type: string
                          nullable: true
                        file_page:
                          type: integer
                          nullable: true
                        start_idx:
                          type: integer
                          nullable: true
                        end_idx:
                          type: integer
                          nullable: true

  /v1/stream:
    post:
      summary: Stream generated text with retrieved passages
      description: >
        This endpoint streams the generated text line by line. The `retrieved_passage`
        is sent first, followed by the `result` streamed incrementally.
      requestBody:
        required: true
        content:
          application/json:
            schema:
              type: object
              properties:
                query:
                  type: string
                  description: The query string
                result_column:
                  type: string
                  description: The result column name
                  default: generated_texts
              required:
                - query
      responses:
        '200':
          description: Successful response with streaming
          content:
            text/event-stream:
              schema:
                type: object
                properties:
                  result:
                    oneOf:
                      - type: string
                      - type: array
                        items:
                          type: string
                    description: The result text or list of texts (streamed line by line)
                  retrieved_passage:
                    type: array
                    items:
                      type: object
                      properties:
                        content:
                          type: string
                        doc_id:
                          type: string
                        filepath:
                          type: string
                          nullable: true
                        file_page:
                          type: integer
                          nullable: true
                        start_idx:
                          type: integer
                          nullable: true
                        end_idx:
                          type: integer
                          nullable: true

  /version:
    get:
      summary: Get the API version
      description: Returns the current version of the API as a string.
      responses:
        '200':
          description: Successful response
          content:
            application/json:
              schema:
                type: object
                properties:
                  version:
                    type: string
                    description: The version of the API
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -3,3 +3,4 @@ pytest-env | |
pytest-xdist | ||
pytest-asyncio | ||
aioresponses | ||
asyncstdlib |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,27 @@ | ||
# AutoRAG pipeline config: bm25 retrieval -> fstring prompt_maker -> openai generator.
node_lines:
  - node_line_name: retrieve_node_line
    nodes:
      - node_type: retrieval  # represents run_node function
        strategy:  # essential for every node
          metrics: [retrieval_f1, retrieval_recall]
        top_k: 10  # node param, which adapt to every module in this node.
        modules:
          - module_type: bm25  # for testing env variable
            bm25_tokenizer: [ facebook/opt-125m, porter_stemmer ]
      - node_type: prompt_maker
        strategy:
          metrics: [bleu]
          generator_modules:
            - module_type: llama_index_llm
              llm: mock
        modules:
          - module_type: fstring
            prompt: "Tell me something about the question: {query} \n\n {retrieved_contents}"
      - node_type: generator
        strategy:
          metrics:
            - metric_name: bleu
        modules:
          - module_type: openai_llm
            llm: gpt-4o-mini
            temperature: 0.5