Refactor API server with streaming and passage return (#810)
* add deploy package instead of deploy.py and make OpenAPI spec for new API server

* moved from deploy.py

* make /v1/run api

* add data format extra columns

* add stream async generator functions at the generators

* refactor api server to use extract_retrieve_passage for simplicity

* checkpoint before implementing quart

* streaming working!!!

* edit documentation and fix error at GradioRunner

* resolve test errors

* refactor API endpoint docs

---------

Co-authored-by: jeffrey <[email protected]>
vkehfdl1 and jeffrey authored Oct 8, 2024
1 parent 504dfcb commit 98c35c0
Showing 22 changed files with 897 additions and 126 deletions.
9 changes: 6 additions & 3 deletions README.md
@@ -366,14 +366,17 @@ You can run this pipeline as an API server.
Check out the API endpoint at [here](deploy/api_endpoint.md).

```python
from autorag.deploy import Runner
import nest_asyncio
from autorag.deploy import ApiRunner

runner = Runner.from_trial_folder('/your/path/to/trial_dir')
nest_asyncio.apply()

runner = ApiRunner.from_trial_folder('/your/path/to/trial_dir')
runner.run_api_server()
```

```bash
autorag run_api --config_path your/path/to/pipeline.yaml --host 0.0.0.0 --port 8000
autorag run_api --trial_dir your/path/to/trial_dir --host 0.0.0.0 --port 8000
```

The CLI command uses the extracted config YAML file. If you want to learn more about it, check out [here](https://docs.auto-rag.com/tutorial.html#extract-pipeline-and-evaluate-test-dataset).
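For reference, a minimal sketch of that extraction step, assuming the `extract_best_config` helper re-exported from `autorag.deploy` in this commit and placeholder paths:

```python
# Minimal sketch: extract the best pipeline config from a finished trial folder,
# then serve it with the new ApiRunner. All paths here are placeholders.
import nest_asyncio

from autorag.deploy import ApiRunner, extract_best_config

nest_asyncio.apply()

trial_dir = "/your/path/to/trial_dir"      # a finished trial folder
yaml_path = "./extracted_pipeline.yaml"    # where the extracted config is written

extract_best_config(trial_path=trial_dir, output_path=yaml_path)

runner = ApiRunner.from_yaml(yaml_path, project_dir="/your/project/dir")
runner.run_api_server(host="0.0.0.0", port=8000)
```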
21 changes: 17 additions & 4 deletions autorag/cli.py
@@ -7,10 +7,11 @@
from typing import Optional

import click
import nest_asyncio

from autorag import dashboard
from autorag.deploy import Runner
from autorag.deploy import extract_best_config as original_extract_best_config
from autorag.deploy.api import ApiRunner
from autorag.evaluator import Evaluator
from autorag.validator import Validator

@@ -48,15 +49,27 @@ def evaluate(config, qa_data_path, corpus_data_path, project_dir):


@click.command()
@click.option("--config_path", type=str, help="Path to extracted config yaml file.")
@click.option(
"--config_path", type=str, help="Path to extracted config yaml file.", default=None
)
@click.option("--host", type=str, default="0.0.0.0", help="Host address")
@click.option("--port", type=int, default=8000, help="Port number")
@click.option(
"--trial_dir",
type=click.Path(file_okay=False, dir_okay=True, exists=True),
default=None,
help="Path to trial directory.",
)
@click.option(
"--project_dir", help="Path to project directory.", type=str, default=None
)
def run_api(config_path, host, port, project_dir):
runner = Runner.from_yaml(config_path, project_dir=project_dir)
def run_api(config_path, host, port, trial_dir, project_dir):
if trial_dir is None:
runner = ApiRunner.from_yaml(config_path, project_dir=project_dir)
else:
runner = ApiRunner.from_trial_folder(trial_dir)
logger.info(f"Running API server at {host}:{port}...")
nest_asyncio.apply()
runner.run_api_server(host, port)


2 changes: 1 addition & 1 deletion autorag/data/qa/schema.py
@@ -79,7 +79,7 @@ def to_parquet(self, save_path: str):
"""
if not save_path.endswith(".parquet"):
raise ValueError("save_path must be ended with .parquet")
save_df = self.data[["doc_id", "contents", "metadata"]].reset_index(drop=True)
save_df = self.data.reset_index(drop=True)
save_df.to_parquet(save_path)

def batch_apply(
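With this change, `to_parquet` persists every column of the corpus rather than only `doc_id`, `contents`, and `metadata`. A small illustration with plain pandas (not the `Corpus` class itself) of the columns that now survive the round trip:

```python
# Illustration with plain pandas: optional corpus columns such as "path" and
# "start_end_idx" are now kept when the corpus is written to parquet.
from datetime import datetime

import pandas as pd

corpus_df = pd.DataFrame(
    {
        "doc_id": ["doc-0"],
        "contents": ["AutoRAG is an AutoML tool for RAG."],
        "metadata": [{"last_modified_datetime": datetime.now()}],
        "path": ["raw_docs/autorag_intro.pdf"],
        "start_end_idx": [(0, 34)],
    }
)
corpus_df.to_parquet("corpus.parquet", engine="pyarrow")

print(set(pd.read_parquet("corpus.parquet", engine="pyarrow").columns))
# expected: {"doc_id", "contents", "metadata", "path", "start_end_idx"}
```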
9 changes: 9 additions & 0 deletions autorag/deploy/__init__.py
@@ -0,0 +1,9 @@
from .base import (
extract_node_line_names,
extract_node_strategy,
summary_df_to_yaml,
extract_best_config,
Runner,
)
from .api import ApiRunner
from .gradio import GradioRunner
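Because the new package re-exports the previous symbols, existing `from autorag.deploy import ...` statements keep working; a quick sanity check of the names listed above:

```python
# Quick check that the old flat-module imports still resolve via the new package.
from autorag.deploy import ApiRunner, GradioRunner, Runner, extract_best_config

print(Runner, ApiRunner, GradioRunner, extract_best_config)
```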
227 changes: 227 additions & 0 deletions autorag/deploy/api.py
@@ -0,0 +1,227 @@
import logging
import os
import pathlib
import uuid
from typing import Dict, Optional, List, Union

import pandas as pd
from quart import Quart, request, jsonify
from quart.helpers import stream_with_context
from pydantic import BaseModel, ValidationError

from autorag.deploy.base import BaseRunner
from autorag.nodes.generator.base import BaseGenerator
from autorag.utils import fetch_contents

logger = logging.getLogger("AutoRAG")

deploy_dir = pathlib.Path(__file__).parent
root_dir = pathlib.Path(__file__).parent.parent

VERSION_PATH = os.path.join(root_dir, "VERSION")


class QueryRequest(BaseModel):
query: str
result_column: Optional[str] = "generated_texts"


class RetrievedPassage(BaseModel):
content: str
doc_id: str
filepath: Optional[str] = None
file_page: Optional[int] = None
start_idx: Optional[int] = None
end_idx: Optional[int] = None


class RunResponse(BaseModel):
result: Union[str, List[str]]
retrieved_passage: List[RetrievedPassage]


class VersionResponse(BaseModel):
version: str


empty_retrieved_passage = RetrievedPassage(
content="", doc_id="", filepath=None, file_page=None, start_idx=None, end_idx=None
)


class ApiRunner(BaseRunner):
def __init__(self, config: Dict, project_dir: Optional[str] = None):
super().__init__(config, project_dir)
self.app = Quart(__name__)

data_dir = os.path.join(project_dir, "data")
self.corpus_df = pd.read_parquet(
os.path.join(data_dir, "corpus.parquet"), engine="pyarrow"
)
self.__add_api_route()

def __add_api_route(self):
@self.app.route("/v1/run", methods=["POST"])
async def run_query():
try:
data = await request.get_json()
data = QueryRequest(**data)
except ValidationError as e:
return jsonify(e.errors()), 400

previous_result = pd.DataFrame(
{
"qid": str(uuid.uuid4()),
"query": [data.query],
"retrieval_gt": [[]],
"generation_gt": [""],
}
) # pseudo qa data for execution
for module_instance, module_param in zip(
self.module_instances, self.module_params
):
new_result = module_instance.pure(
previous_result=previous_result, **module_param
)
duplicated_columns = previous_result.columns.intersection(
new_result.columns
)
drop_previous_result = previous_result.drop(columns=duplicated_columns)
previous_result = pd.concat([drop_previous_result, new_result], axis=1)

# Simulate processing the query
generated_text = previous_result[data.result_column].tolist()[0]
retrieved_passage = self.extract_retrieve_passage(previous_result)

response = RunResponse(
result=generated_text, retrieved_passage=retrieved_passage
)

return jsonify(response.model_dump()), 200

@self.app.route("/v1/stream", methods=["POST"])
async def stream_query():
try:
data = await request.get_json()
data = QueryRequest(**data)
except ValidationError as e:
return jsonify(e.errors()), 400

@stream_with_context
async def generate():
previous_result = pd.DataFrame(
{
"qid": str(uuid.uuid4()),
"query": [data.query],
"retrieval_gt": [[]],
"generation_gt": [""],
}
) # pseudo qa data for execution

for module_instance, module_param in zip(
self.module_instances, self.module_params
):
if not isinstance(module_instance, BaseGenerator):
new_result = module_instance.pure(
previous_result=previous_result, **module_param
)
duplicated_columns = previous_result.columns.intersection(
new_result.columns
)
drop_previous_result = previous_result.drop(
columns=duplicated_columns
)
previous_result = pd.concat(
[drop_previous_result, new_result], axis=1
)
else:
retrieved_passages = self.extract_retrieve_passage(
previous_result
)
response = RunResponse(
result="", retrieved_passage=retrieved_passages
)
yield response.model_dump_json().encode("utf-8")
# Start streaming of the result
assert len(previous_result) == 1
prompt: str = previous_result["prompts"].tolist()[0]
async for delta in module_instance.stream(
prompt=prompt, **module_param
):
response = RunResponse(
result=delta,
retrieved_passage=[empty_retrieved_passage],
)
yield response.model_dump_json().encode("utf-8")

return generate(), 200, {"X-Something": "value"}

@self.app.route("/version", methods=["GET"])
def get_version():
with open(VERSION_PATH, "r") as f:
version = f.read().strip()
response = VersionResponse(version=version)
return jsonify(response.model_dump()), 200

def run_api_server(self, host: str = "0.0.0.0", port: int = 8000, **kwargs):
"""
Run the pipeline as an API server.
You can send a POST request to `http://host:port/v1/run` with a json body like below:
.. Code:: json
{
"query": "your query",
"result_column": "generated_texts"
}
And it returns a json response like below:
.. Code:: json
{
"result": "your answer",
"retrieved_passage": []
}
:param host: The host of the api server.
:param port: The port of the api server.
:param kwargs: Other arguments for Quart app.run.
"""
logger.info(f"Run api server at {host}:{port}")
self.app.run(host=host, port=port, **kwargs)

def extract_retrieve_passage(self, df: pd.DataFrame) -> List[RetrievedPassage]:
retrieved_ids: List[str] = df["retrieved_ids"].tolist()[0]
contents = fetch_contents(self.corpus_df, [retrieved_ids])[0]
if "path" in self.corpus_df.columns:
paths = fetch_contents(self.corpus_df, [retrieved_ids], column_name="path")[
0
]
else:
paths = [None] * len(retrieved_ids)
metadatas = fetch_contents(
self.corpus_df, [retrieved_ids], column_name="metadata"
)[0]
if "start_end_idx" in self.corpus_df.columns:
start_end_indices = fetch_contents(
self.corpus_df, [retrieved_ids], column_name="start_end_idx"
)[0]
else:
start_end_indices = [None] * len(retrieved_ids)
return list(
map(
lambda content, doc_id, path, metadata, start_end_idx: RetrievedPassage(
content=content,
doc_id=doc_id,
filepath=path,
file_page=metadata.get("page", None),
start_idx=start_end_idx[0] if start_end_idx else None,
end_idx=start_end_idx[1] if start_end_idx else None,
),
contents,
retrieved_ids,
paths,
metadatas,
start_end_indices,
)
)
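Each chunk yielded by `/v1/stream` above is a complete `RunResponse` serialized with `model_dump_json()`: the first chunk carries the retrieved passages with an empty `result`, and later chunks carry the generated text. A rough client-side sketch, assuming `httpx` is installed and that each received chunk lines up with one server-side yield (plain HTTP chunking does not strictly guarantee this, so a real client may need to buffer):

```python
# Hedged sketch of consuming /v1/stream: assumes httpx and a 1:1 mapping between
# server-side yields and received chunks, which HTTP chunking may not guarantee.
import asyncio
import json

import httpx


async def stream_run(query: str, base_url: str = "http://localhost:8000"):
    payload = {"query": query, "result_column": "generated_texts"}
    async with httpx.AsyncClient(timeout=None) as client:
        async with client.stream("POST", f"{base_url}/v1/stream", json=payload) as resp:
            first = True
            async for chunk in resp.aiter_bytes():
                data = json.loads(chunk.decode("utf-8"))
                if first:
                    print("retrieved passages:", len(data["retrieved_passage"]))
                    first = False
                else:
                    print("result chunk:", data["result"])


asyncio.run(stream_run("What is AutoRAG?"))
```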
83 changes: 7 additions & 76 deletions autorag/deploy.py → autorag/deploy/base.py
@@ -6,13 +6,11 @@

import pandas as pd
import yaml
from pydantic import BaseModel
import gradio as gr
from flask import Flask, request

from autorag.support import get_support_modules
from autorag.utils.util import load_summary_file


logger = logging.getLogger("AutoRAG")


@@ -122,12 +120,12 @@ def extract_best_config(trial_path: str, output_path: Optional[str] = None) -> D
return yaml_dict


class Runner:
class BaseRunner:
def __init__(self, config: Dict, project_dir: Optional[str] = None):
self.config = config
project_dir = os.getcwd() if project_dir is None else project_dir
self.app = Flask(__name__)
self.__add_api_route()
# self.app = Flask(__name__)
# self.__add_api_route()

# init modules
node_lines = deepcopy(self.config["node_lines"])
@@ -158,7 +156,7 @@ def from_yaml(cls, yaml_path: str, project_dir: Optional[str] = None):
:param yaml_path: The path of the YAML file.
:param project_dir: The path of the project directory.
Default is the current directory.
Default is the current directory.
:return: Initialized Runner.
"""
with open(yaml_path, "r") as f:
@@ -182,6 +180,8 @@ def from_trial_folder(cls, trial_path: str):
config = extract_best_config(trial_path)
return cls(config, project_dir=os.path.dirname(trial_path))


class Runner(BaseRunner):
def run(self, query: str, result_column: str = "generated_texts"):
"""
Run the pipeline with query.
@@ -214,72 +214,3 @@ def run(self, query: str, result_column: str = "generated_texts"):
previous_result = pd.concat([drop_previous_result, new_result], axis=1)

return previous_result[result_column].tolist()[0]

def __add_api_route(self):
@self.app.route("/run", methods=["POST"])
def run_pipeline():
runner_input = RunnerInput(**request.json)
query = runner_input.query
result_column = runner_input.result_column
result = self.run(query, result_column)
return {result_column: result}

def run_api_server(self, host: str = "0.0.0.0", port: int = 8000, **kwargs):
"""
Run the pipeline as api server.
You can send POST request to `http://host:port/run` with json body like below:
.. Code:: json
{
"query": "your query",
"result_column": "generated_texts"
}
And it returns json response like below:
.. Code:: json
{
"answer": "your answer"
}
:param host: The host of the api server.
:param port: The port of the api server.
:param kwargs: Other arguments for Flask app.run.
"""
logger.info(f"Run api server at {host}:{port}")
self.app.run(host=host, port=port, **kwargs)

def run_web(
self,
server_name: str = "0.0.0.0",
server_port: int = 7680,
share: bool = False,
**kwargs,
):
"""
Run web interface to interact pipeline.
You can access the web interface at `http://server_name:server_port` in your browser
:param server_name: The host of the web. Default is 0.0.0.0.
:param server_port: The port of the web. Default is 7680.
:param share: Whether to create a publicly shareable link. Default is False.
:param kwargs: Other arguments for gr.ChatInterface.launch.
"""

logger.info(f"Run web interface at http://{server_name}:{server_port}")

def get_response(message, _):
return self.run(message)

gr.ChatInterface(
get_response, title="📚 AutoRAG", retry_btn=None, undo_btn=None
).launch(
server_name=server_name, server_port=server_port, share=share, **kwargs
)


class RunnerInput(BaseModel):
query: str
result_column: str = "generated_texts"
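After this split, the plain `Runner` keeps only the in-process `run` method, while the Flask routes above are removed in favor of the Quart-based `ApiRunner`. A minimal sketch of the remaining usage, with a placeholder path:

```python
# Minimal sketch: run the optimized pipeline in-process with the plain Runner.
from autorag.deploy import Runner

runner = Runner.from_trial_folder("/your/path/to/trial_dir")
answer = runner.run("What is AutoRAG?", result_column="generated_texts")
print(answer)
```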
74 changes: 74 additions & 0 deletions autorag/deploy/gradio.py
@@ -0,0 +1,74 @@
import logging
import uuid

import pandas as pd

from autorag.deploy.base import BaseRunner

import gradio as gr


logger = logging.getLogger("AutoRAG")


class GradioRunner(BaseRunner):
def run_web(
self,
server_name: str = "0.0.0.0",
server_port: int = 7680,
share: bool = False,
**kwargs,
):
"""
Run web interface to interact pipeline.
You can access the web interface at `http://server_name:server_port` in your browser
:param server_name: The host of the web. Default is 0.0.0.0.
:param server_port: The port of the web. Default is 7680.
:param share: Whether to create a publicly shareable link. Default is False.
:param kwargs: Other arguments for gr.ChatInterface.launch.
"""

logger.info(f"Run web interface at http://{server_name}:{server_port}")

def get_response(message, _):
return self.run(message)

gr.ChatInterface(
get_response, title="📚 AutoRAG", retry_btn=None, undo_btn=None
).launch(
server_name=server_name, server_port=server_port, share=share, **kwargs
)

def run(self, query: str, result_column: str = "generated_texts"):
"""
Run the pipeline with query.
The loaded pipeline must start with a single query,
so the first module of the pipeline must be `query_expansion` or `retrieval` module.
:param query: The query of the user.
:param result_column: The result column name for the answer.
Default is `generated_texts`, which is the output of the `generation` module.
:return: The result of the pipeline.
"""
previous_result = pd.DataFrame(
{
"qid": str(uuid.uuid4()),
"query": [query],
"retrieval_gt": [[]],
"generation_gt": [""],
}
) # pseudo qa data for execution
for module_instance, module_param in zip(
self.module_instances, self.module_params
):
new_result = module_instance.pure(
previous_result=previous_result, **module_param
)
duplicated_columns = previous_result.columns.intersection(
new_result.columns
)
drop_previous_result = previous_result.drop(columns=duplicated_columns)
previous_result = pd.concat([drop_previous_result, new_result], axis=1)

return previous_result[result_column].tolist()[0]
135 changes: 135 additions & 0 deletions autorag/deploy/swagger.yml
@@ -0,0 +1,135 @@
openapi: 3.0.0
info:
title: Example API
version: 1.0.0
paths:
/v1/run:
post:
summary: Run a query and get generated text with retrieved passages
requestBody:
required: true
content:
application/json:
schema:
type: object
properties:
query:
type: string
description: The query string
result_column:
type: string
description: The result column name
default: generated_texts
required:
- query
responses:
'200':
description: Successful response
content:
application/json:
schema:
type: object
properties:
result:
oneOf:
- type: string
- type: array
items:
type: string
description: The result text or list of texts
retrieved_passage:
type: array
items:
type: object
properties:
content:
type: string
doc_id:
type: string
filepath:
type: string
nullable: true
file_page:
type: integer
nullable: true
start_idx:
type: integer
nullable: true
end_idx:
type: integer
nullable: true

/v1/stream:
post:
summary: Stream generated text with retrieved passages
description: >
This endpoint streams the generated text line by line. The `retrieved_passage`
is sent first, followed by the `result` streamed incrementally.
requestBody:
required: true
content:
application/json:
schema:
type: object
properties:
query:
type: string
description: The query string
result_column:
type: string
description: The result column name
default: generated_texts
required:
- query
responses:
'200':
description: Successful response with streaming
content:
text/event-stream:
schema:
type: object
properties:
result:
oneOf:
- type: string
- type: array
items:
type: string
description: The result text or list of texts (streamed line by line)
retrieved_passage:
type: array
items:
type: object
properties:
content:
type: string
doc_id:
type: string
filepath:
type: string
nullable: true
file_page:
type: integer
nullable: true
start_idx:
type: integer
nullable: true
end_idx:
type: integer
nullable: true

/version:
get:
summary: Get the API version
description: Returns the current version of the API as a string.
responses:
'200':
description: Successful response
content:
application/json:
schema:
type: object
properties:
version:
type: string
description: The version of the API
4 changes: 4 additions & 0 deletions autorag/nodes/generator/base.py
@@ -44,6 +44,10 @@ def structured_output(self, prompts: List[str], output_cls):
result.append(None)
return result

@abc.abstractmethod
async def stream(self, prompt: str, **kwargs):
pass


def generator_node(func):
@functools.wraps(func)
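Since `stream` is now an abstract coroutine on `BaseGenerator`, every concrete generator must provide one (the vLLM module below simply raises `NotImplementedError`). A hypothetical sketch of the async-generator shape a subclass implements, with the other required methods omitted:

```python
# Hypothetical sketch of the `stream` shape a BaseGenerator subclass provides.
# The remaining abstract/required methods of the class are omitted here.
from autorag.nodes.generator.base import BaseGenerator


class EchoGenerator(BaseGenerator):
    async def stream(self, prompt: str, **kwargs):
        # Yield progressively longer prefixes, mimicking token-by-token output.
        text = ""
        for word in prompt.split():
            text = f"{text} {word}".strip()
            yield text
```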
6 changes: 6 additions & 0 deletions autorag/nodes/generator/llama_index_llm.py
@@ -85,3 +85,9 @@ def _pure(
tokenized_ids = tokenizer(generated_texts).data["input_ids"]
pseudo_log_probs = list(map(lambda x: [0.5] * len(x), tokenized_ids))
return generated_texts, tokenized_ids, pseudo_log_probs

async def stream(self, prompt: str, **kwargs):
async for completion_response in await self.llm_instance.astream_complete(
prompt
):
yield completion_response.text
31 changes: 30 additions & 1 deletion autorag/nodes/generator/openai_llm.py
@@ -178,11 +178,40 @@ def structured_output(self, prompts: List[str], output_cls, **kwargs):
result = loop.run_until_complete(process_batch(tasks, self.batch))
return result

async def stream(self, prompt: str, **kwargs):
if kwargs.get("logprobs") is not None:
kwargs.pop("logprobs")
logger.warning(
"parameter logprob does not effective. It always set to False."
)
if kwargs.get("n") is not None:
kwargs.pop("n")
logger.warning("parameter n does not effective. It always set to 1.")

prompt = truncate_by_token(prompt, self.tokenizer, self.max_token_size)

openai_chat_params = pop_params(self.client.chat.completions.create, kwargs)

stream = await self.client.chat.completions.create(
model=self.llm,
messages=[
{"role": "user", "content": prompt},
],
logprobs=False,
n=1,
stream=True,
**openai_chat_params,
)
result = ""
async for chunk in stream:
if chunk.choices[0].delta.content is not None:
result += chunk.choices[0].delta.content
yield result

async def get_structured_result(self, prompt: str, output_cls, **kwargs):
response = await self.client.beta.chat.completions.parse(
model=self.llm,
messages=[
{"role": "system", "content": "Structured Output"},
{"role": "user", "content": prompt},
],
response_format=output_cls,
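Note that this implementation yields the accumulated text so far (`result` grows on every chunk) rather than only the newest delta, which is what the streaming tests below assert. A small consumption sketch, assuming `OPENAI_API_KEY` is set in the environment:

```python
# Hedged sketch: each yielded value is the full text generated so far, so printing
# only the new suffix recovers incremental output. Assumes OPENAI_API_KEY is set.
import asyncio

from autorag.nodes.generator.openai_llm import OpenAILLM


async def main():
    llm = OpenAILLM(project_dir=".", llm="gpt-4o-mini")
    shown = ""
    async for so_far in llm.stream("Say hello in three languages."):
        print(so_far[len(shown):], end="", flush=True)
        shown = so_far
    print()


asyncio.run(main())
```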
3 changes: 3 additions & 0 deletions autorag/nodes/generator/vllm.py
@@ -106,3 +106,6 @@ def _pure(
to_list(generated_token_ids),
to_list(generated_log_probs),
)

async def stream(self, prompt: str, **kwargs):
raise NotImplementedError
23 changes: 23 additions & 0 deletions docs/source/data_creation/data_format.md
@@ -139,6 +139,23 @@ But from an early version of AutoRAG, it only supports text.
Plus, we have plans to support chunking optimization for your data.
```

### path (Optional, but recommended)

The origin path of the passage. When you include this, you can track which file each passage came from.
It is especially useful for debugging or for displaying the origin of a passage after it is retrieved.

When you use AutoRAG's original parsing and chunking, this column is filled automatically.

The type is `string`.

### start_end_idx (Optional, but recommended)

The start and end indices of the passage within the original parsed document. With this, you can update an existing QA dataset to a new corpus when you still have the raw data.

This column is filled automatically when you use AutoRAG's original parsing and chunking.

The type is a tuple of int `(start, end)`.

### metadata

Metadata for your passages.
@@ -148,6 +165,12 @@ You must include `last_modified_datetime` key at metadata.
We recommend you to include modified datetime of your passages, but it is okay to put `datetime.now()` if you don't want to use time-related feature.
The value of `last_modified_datetime` must be an instance of python `datetime.datetime` class.

As optional metadata, you can add `page`. This is helpful when you want to display the origin of the passage.

In addition, to use `prev_next_augmenter`, you must include `prev_id` and `next_id` in the metadata.

These fields are filled automatically when you use AutoRAG's original parsing and chunking.
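
Putting the keys above together, one corpus row's metadata might look like the following (a hedged illustration; only `last_modified_datetime` is strictly required):

```python
# Hedged illustration of one corpus row's metadata using the keys described above.
from datetime import datetime

metadata = {
    "last_modified_datetime": datetime.now(),  # required
    "page": 3,                                 # optional: origin page of the passage
    "prev_id": "doc-41",                       # needed only for prev_next_augmenter
    "next_id": "doc-43",                       # needed only for prev_next_augmenter
}
```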

```{tip}
If you don't have any metadata, you can put an empty dictionary.
It will create a default metadata for you. (like `last_modified_datetime` with `datetime.now()`)
202 changes: 184 additions & 18 deletions docs/source/deploy/api_endpoint.md
@@ -9,45 +9,211 @@ myst:

## Running API server

As mentioned in the tutorial, you can run api server from the extracted YAML file or trial folder as follows:
As mentioned in the tutorial, you can run the API server as follows:

```python
from autorag.deploy import Runner
from autorag.deploy import ApiRunner
import nest_asyncio

runner = Runner.from_yaml('your/path/to/pipeline.yaml')
nest_asyncio.apply()

runner = ApiRunner.from_yaml('your/path/to/pipeline.yaml', project_dir='your/project/directory')
runner.run_api_server()
```

runner = Runner.from_trial_folder('your/path/to/trial_folder')
or

```python
from autorag.deploy import ApiRunner
import nest_asyncio

nest_asyncio.apply()

runner = ApiRunner.from_trial_folder('/your/path/to/trial_dir')
runner.run_api_server()
```

```bash
autorag run_api --config_path your/path/to/pipeline.yaml --host 0.0.0.0 --port 8000
autorag run_api --trial_dir /trial/dir/0 --host 0.0.0.0 --port 8000
```

## API Endpoint

You can use AutoRAG api server using `/run` endpoint.
It is a `POST` operation, and you can specify a user query as `query` and result column as `result_column` in the request body.
Then, you can get a response with result looks like `{'result_column': result}`
The `result_column` is the same as the `result_column` you specified in the request body.
And the `result_column` must be one of the last output of your pipeline. The default is 'answer.'
The sections below document each API endpoint from the OpenAPI specification: its purpose, request parameters, and response structure.

---

## Example API Documentation

### Version: 1.0.0

---

### Endpoints

#### 1. `/v1/run` (POST)

- **Summary**: Run a query and get generated text with retrieved passages.
- **Request Body**:
- **Content Type**: `application/json`
- **Schema**:
- **Properties**:
- `query` (string, required): The query string.
- `result_column` (string, optional): The result column name. Default is `generated_texts`.
- **Responses**:
- **200 OK**:
- **Content Type**: `application/json`
- **Schema**:
- **Properties**:
- `result` (string or array of strings): The result text or list of texts.
- `retrieved_passage` (array of objects): List of retrieved passages.
- **Properties**:
- `content` (string): The content of the passage.
- `doc_id` (string): Document ID.
- `filepath` (string, nullable): File path.
- `file_page` (integer, nullable): File page number.
- `start_idx` (integer, nullable): Start index.
- `end_idx` (integer, nullable): End index.

---

#### 2. `/v1/stream` (POST)

- **Summary**: Stream generated text with retrieved passages.
- **Description**: This endpoint streams the generated text line by line. The `retrieved_passage` is sent first, followed by the `result` streamed incrementally.
- **Request Body**:
- **Content Type**: `application/json`
- **Schema**:
- **Properties**:
- `query` (string, required): The query string.
- `result_column` (string, optional): The result column name. Default is `generated_texts`.
- **Responses**:
- **200 OK**:
- **Content Type**: `text/event-stream`
- **Schema**:
- **Properties**:
- `result` (string or array of strings): The result text or list of texts (streamed line by line).
- `retrieved_passage` (array of objects): List of retrieved passages.
- **Properties**:
- `content` (string): The content of the passage.
- `doc_id` (string): Document ID.
- `filepath` (string, nullable): File path.
- `file_page` (integer, nullable): File page number.
- `start_idx` (integer, nullable): Start index.
- `end_idx` (integer, nullable): End index.

---

#### 3. `/version` (GET)

- **Summary**: Get the API version.
- **Description**: Returns the current version of the API as a string.
- **Responses**:
- **200 OK**:
- **Content Type**: `application/json`
- **Schema**:
- **Properties**:
- `version` (string): The version of the API.

---

## API client usage example

Below are Python examples using the `requests` library and equivalent `curl` commands for each of the API endpoints described above.

### Python Sample Code

First, ensure you have the `requests` library installed. You can install it using pip if you haven't already:

```bash
curl -X POST "http://your-host:your-port/run" -H "accept: application/json" -H "Content-Type: application/json" -d "{\"query\":\"your question\", \"result_column\":\"your_result_column\"}"
pip install requests
```

Here's the Python client code for each endpoint:

```python
import requests
import json

# Base URL of the API
BASE_URL = "http://example.com:8000" # Replace with the actual base URL of the API

url = "http://your-host:your-port/run"
payload = "{\"query\":\"your question\", \"result_column\":\"your_result_column\"}"
headers = {
'accept': "application/json",
'Content-Type': "application/json"
def run_query(query, result_column="generated_texts"):
url = f"{BASE_URL}/v1/run"
payload = {
"query": query,
"result_column": result_column
}
response = requests.post(url, json=payload)
if response.status_code == 200:
return response.json()
else:
response.raise_for_status()

def stream_query(query, result_column="generated_texts"):
url = f"{BASE_URL}/v1/stream"
payload = {
"query": query,
"result_column": result_column
}
response = requests.post(url, json=payload, stream=True)
if response.status_code == 200:
for i, chunk in enumerate(response.iter_content(chunk_size=None)):
if chunk:
# Decode the chunk and print it
data = json.loads(chunk.decode("utf-8"))
if i == 0:
retrieved_passages = data["retrieved_passage"] # The retrieved passages
print(data["result"], end="")
else:
response.raise_for_status()

def get_version():
url = f"{BASE_URL}/version"
response = requests.get(url)
if response.status_code == 200:
return response.json()
else:
response.raise_for_status()

# Example usage
if __name__ == "__main__":
# Run a query
result = run_query("example query")
print("Run Query Result:", result)

# Stream a query
print("Stream Query Result:")
stream_query("example query")

# Get API version
version = get_version()
print("API Version:", version)
```

### `curl` Commands

Here are the equivalent `curl` commands for each endpoint:

#### `/v1/run` (POST)

```bash
curl -X POST "http://example.com/v1/run" \
-H "Content-Type: application/json" \
-d '{"query": "example query", "result_column": "generated_texts"}'
```

#### `/v1/stream` (POST)

```bash
curl -X POST "http://example.com/v1/stream" \
-H "Content-Type: application/json" \
-d '{"query": "example query", "result_column": "generated_texts"}' \
--no-buffer
```

response = requests.request("POST", url, data=payload, headers=headers)
#### `/version` (GET)

print(response.text)
```bash
curl -X GET "http://example.com/version"
```
24 changes: 15 additions & 9 deletions docs/source/tutorial.md
@@ -232,23 +232,29 @@ You can run this pipeline as an API server.
Check out the API endpoint at [here](deploy/api_endpoint.md).

```python
from autorag.deploy import Runner
from autorag.deploy import ApiRunner
import nest_asyncio
runner = Runner.from_yaml('your/path/to/pipeline.yaml')
nest_asyncio.apply()
runner = ApiRunner.from_yaml('your/path/to/pipeline.yaml', project_dir='your/project/directory')
runner.run_api_server()
```

or

```python
from autorag.deploy import Runner
from autorag.deploy import ApiRunner
import nest_asyncio
runner = Runner.from_trial_folder('/your/path/to/trial_dir')
nest_asyncio.apply()
runner = ApiRunner.from_trial_folder('/your/path/to/trial_dir')
runner.run_api_server()
```

```bash
autorag run_api --config_path your/path/to/pipeline.yaml --host 0.0.0.0 --port 8000
autorag run_api --trial_dir /trial/dir/0 --host 0.0.0.0 --port 8000
```

```{admonition} Want to specify project folder?
@@ -262,16 +268,16 @@ you can run this pipeline as a web interface.
Check out the web interface at [here](deploy/web.md).

```python
from autorag.deploy import Runner
from autorag.deploy import GradioRunner
runner = Runner.from_yaml('your/path/to/pipeline.yaml')
runner = GradioRunner.from_yaml('your/path/to/pipeline.yaml')
runner.run_web()
```

```python
from autorag.deploy import Runner
from autorag.deploy import GradioRunner
runner = Runner.from_trial_folder('/your/path/to/trial_dir')
runner = GradioRunner.from_trial_folder('/your/path/to/trial_dir')
runner.run_web()
```

4 changes: 3 additions & 1 deletion requirements.txt
@@ -14,7 +14,6 @@ rouge_score # for rouge score
rich # for pretty logging
chromadb>=0.5.0 # for vectordb retrieval
click # for cli
Flask # for api server
torch # for monot5 reranker
sentencepiece # for monot5 reranker
cohere>=5.8.0 # for cohere services
@@ -27,6 +26,9 @@ llmlingua # for longllmlingua
peft
optimum[openvino,nncf] # for openvino reranker

### API server ###
quart

### LlamaIndex ###
llama-index>=0.11.0
llama-index-core>=0.11.0
8 changes: 7 additions & 1 deletion tests/autorag/data/qa/test_schema.py
@@ -186,7 +186,13 @@ def test_update_corpus():
"generation_gt",
}
loaded_corpus = pd.read_parquet(corpus_path.name, engine="pyarrow")
assert set(loaded_corpus.columns) == {"doc_id", "contents", "metadata"}
assert set(loaded_corpus.columns) == {
"doc_id",
"contents",
"metadata",
"path",
"start_end_idx",
}
corpus_path.close()
os.unlink(corpus_path.name)
qa_path.close()
29 changes: 29 additions & 0 deletions tests/autorag/nodes/generator/test_llama_index_llm.py
@@ -1,3 +1,5 @@
import logging
import os
from unittest.mock import patch

import pandas as pd
@@ -14,8 +16,11 @@
check_generated_tokens,
check_generated_log_probs,
)
from tests.delete_tests import is_github_action
from tests.mock import MockLLM

logger = logging.getLogger("AutoRAG")


@pytest.fixture
def llama_index_llm_instance():
@@ -89,3 +94,27 @@ class TestResponse(BaseModel):
assert output.phone_number == "1234567890"
assert output.age == 30
assert output.is_dead is False


@pytest.mark.skipif(
is_github_action(),
reason="Skipping this test on GitHub Actions because it uses the real OpenAI API.",
)
@pytest.mark.asyncio()
async def test_llama_index_llm_stream():
import asyncstdlib as a

llm_instance = LlamaIndexLLM(
project_dir=".",
llm="openai",
model="gpt-4o-mini",
api_key=os.getenv("OPENAI_API_KEY"),
)
result = []
async for i, s in a.enumerate(
llm_instance.stream("Hello. Tell me about who is Kai Havertz")
):
assert isinstance(s, str)
result.append(s)
if i >= 1:
assert len(result[i]) >= len(result[i - 1])
20 changes: 20 additions & 0 deletions tests/autorag/nodes/generator/test_openai.py
@@ -13,6 +13,7 @@
check_generated_tokens,
check_generated_log_probs,
)
from tests.delete_tests import is_github_action
from tests.mock import mock_openai_chat_create
import openai.resources.beta.chat
from openai.types.chat import (
@@ -132,3 +133,22 @@ def test_openai_llm_structured():
llm = OpenAILLM(project_dir=".", llm="gpt-3.5-turbo")
with pytest.raises(ValueError):
llm.structured_output([prompt], TestResponse)


@pytest.mark.skipif(
is_github_action(),
reason="Skipping this test on GitHub Actions because it uses the real OpenAI API.",
)
@pytest.mark.asyncio()
async def test_openai_llm_stream():
import asyncstdlib as a

llm_instance = OpenAILLM(project_dir=".", llm="gpt-4o-mini-2024-07-18")
result = []
async for i, s in a.enumerate(
llm_instance.stream("Hello. Tell me about who is Kai Havertz")
):
assert isinstance(s, str)
result.append(s)
if i >= 1:
assert len(result[i]) >= len(result[i - 1])
81 changes: 69 additions & 12 deletions tests/autorag/test_deploy.py
@@ -1,7 +1,10 @@
import asyncio
import logging
import os
import pathlib
import tempfile

import nest_asyncio
import pandas as pd
import pytest
import yaml
@@ -13,12 +16,15 @@
extract_node_line_names,
extract_node_strategy,
)
from autorag.deploy.api import ApiRunner
from autorag.evaluator import Evaluator
from tests.delete_tests import is_github_action

root_dir = pathlib.PurePath(os.path.dirname(os.path.realpath(__file__))).parent
resource_dir = os.path.join(root_dir, "resources")

logger = logging.getLogger("AutoRAG")


@pytest.fixture
def evaluator():
@@ -31,6 +37,12 @@ def evaluator():
yield evaluator


@pytest.fixture
def evaluator_trial_done(evaluator):
evaluator.start_trial(os.path.join(resource_dir, "simple_with_llm.yaml"))
yield evaluator


@pytest.fixture
def full_config():
yaml_path = os.path.join(resource_dir, "full.yaml")
@@ -200,21 +212,66 @@ def test_runner_full(evaluator):
def test_runner_api_server(evaluator):
project_dir = evaluator.project_dir
evaluator.start_trial(os.path.join(resource_dir, "simple_mock.yaml"))
runner = Runner.from_trial_folder(os.path.join(project_dir, "0"))
runner = ApiRunner.from_trial_folder(os.path.join(project_dir, "0"))

client = runner.app.test_client()

# Use the TestClient to make a request to the server
response = client.post(
"/run",
json={
"query": "What is the best movie in Korea? Have Korea movie ever won Oscar?",
"result_column": "retrieved_contents",
},
)
assert response.status_code == 200
assert "retrieved_contents" in response.json
retrieved_contents = response.json["retrieved_contents"]
async def post_to_server():
# Use the TestClient to make a request to the server
response = await client.post(
"/v1/run",
json={
"query": "What is the best movie in Korea? Have Korea movie ever won Oscar?",
"result_column": "retrieved_contents",
},
)
json_response = await response.get_json()
return json_response, response.status_code

nest_asyncio.apply()

response_json, response_status_code = asyncio.run(post_to_server())
assert response_status_code == 200
assert "result" in response_json
retrieved_contents = response_json["result"]
assert len(retrieved_contents) == 10
assert isinstance(retrieved_contents, list)
assert isinstance(retrieved_contents[0], str)

retrieved_contents = response_json["retrieved_passage"]
assert len(retrieved_contents) == 10
assert isinstance(retrieved_contents[0]["content"], str)
assert isinstance(retrieved_contents[0]["doc_id"], str)
assert retrieved_contents[0]["filepath"] is None
assert retrieved_contents[0]["file_page"] is None
assert retrieved_contents[0]["start_idx"] is None
assert retrieved_contents[0]["end_idx"] is None


@pytest.mark.skip(reason="This test is not working")
def test_runner_api_server_stream(evaluator_trial_done):
project_dir = evaluator_trial_done.project_dir
runner = ApiRunner.from_trial_folder(os.path.join(project_dir, "0"))
client = runner.app.test_client()

async def post_to_server():
# Use the TestClient to make a request to the server
async with client.request(
"/v1/stream",
method="POST",
headers={"Content-Type": "application/json"},
query_string={
"query": "What is the best movie in Korea? Have Korea movie ever won Oscar?",
},
) as connection:
response = await connection.receive()
# Ensure the response status code is 200
assert connection.status_code == 200

# Collect streamed data
streamed_data = []
async for data in response.body:
streamed_data.append(data)

nest_asyncio.apply()
asyncio.run(post_to_server())
1 change: 1 addition & 0 deletions tests/requirements.txt
@@ -3,3 +3,4 @@ pytest-env
pytest-xdist
pytest-asyncio
aioresponses
asyncstdlib
27 changes: 27 additions & 0 deletions tests/resources/simple_with_llm.yaml
@@ -0,0 +1,27 @@
node_lines:
- node_line_name: retrieve_node_line
nodes:
- node_type: retrieval # represents run_node function
strategy: # essential for every node
metrics: [retrieval_f1, retrieval_recall]
top_k: 10 # node param, which adapt to every module in this node.
modules:
- module_type: bm25 # for testing env variable
bm25_tokenizer: [ facebook/opt-125m, porter_stemmer ]
- node_type: prompt_maker
strategy:
metrics: [bleu]
generator_modules:
- module_type: llama_index_llm
llm: mock
modules:
- module_type: fstring
prompt: "Tell me something about the question: {query} \n\n {retrieved_contents}"
- node_type: generator
strategy:
metrics:
- metric_name: bleu
modules:
- module_type: openai_llm
llm: gpt-4o-mini
temperature: 0.5
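
This resource config is consumed by the new `evaluator_trial_done` fixture via `start_trial`; a rough sketch of running it outside the tests, assuming the usual `Evaluator` constructor arguments (paths here are placeholders):

```python
# Rough sketch: run a trial with simple_with_llm.yaml. The Evaluator constructor
# arguments are assumed; adjust them to your project layout.
from autorag.evaluator import Evaluator

evaluator = Evaluator(
    qa_data_path="./data/qa.parquet",
    corpus_data_path="./data/corpus.parquet",
    project_dir="./my_project",
)
evaluator.start_trial("./tests/resources/simple_with_llm.yaml")
```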
