
[HotFix] AutoRAG api error fix (#1032)
* Fix error in the API with an AutoRAG-ingested corpus

* Bump version to 0.3.11rc2

* Fix missing import
vkehfdl1 authored Dec 3, 2024
1 parent cf2df95 commit e4cd040
Showing 6 changed files with 91 additions and 2 deletions.
2 changes: 1 addition & 1 deletion autorag/VERSION
@@ -1 +1 @@
-0.3.11rc1
+0.3.11rc2
3 changes: 2 additions & 1 deletion autorag/deploy/api.py
@@ -12,7 +12,7 @@
 from autorag.deploy.base import BaseRunner
 from autorag.nodes.generator.base import BaseGenerator
 from autorag.nodes.promptmaker.base import BasePromptMaker
-from autorag.utils import fetch_contents
+from autorag.utils.util import fetch_contents, to_list

 logger = logging.getLogger("AutoRAG")

@@ -279,6 +279,7 @@ def extract_retrieve_passage(self, df: pd.DataFrame) -> List[RetrievedPassage]:
             )[0]
         else:
             start_end_indices = [None] * len(retrieved_ids)
+        start_end_indices = to_list(start_end_indices)
         return list(
             map(
                 lambda content, doc_id, path, metadata, start_end_idx: RetrievedPassage(
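Why the one-line fix works: when the corpus parquet was generated by AutoRAG itself, list-valued metadata such as start_end_idx comes back from pandas as numpy arrays, which the API's JSON encoder cannot serialize. Calling to_list before building the RetrievedPassage objects converts everything to plain Python types first. Below is a minimal sketch of the failure mode; to_list_sketch is a hypothetical stand-in for autorag.utils.util.to_list, not the library's actual implementation:

import json

import numpy as np

def to_list_sketch(item):
    # Hypothetical equivalent of autorag.utils.util.to_list: recursively
    # convert numpy arrays and scalars into plain Python lists and values.
    if isinstance(item, np.ndarray):
        return [to_list_sketch(x) for x in item.tolist()]
    if isinstance(item, np.generic):
        return item.item()
    if isinstance(item, (list, tuple)):
        return [to_list_sketch(x) for x in item]
    return item

# Parquet round-trips list columns as numpy arrays, and json.dumps
# (used when the API builds its response) rejects them:
start_end_indices = np.array([[0, 120], [121, 240]])
try:
    json.dumps({"start_end_idx": start_end_indices})
except TypeError as err:
    print(err)  # Object of type ndarray is not JSON serializable

# After conversion, serialization succeeds:
print(json.dumps({"start_end_idx": to_list_sketch(start_end_indices)}))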
52 changes: 52 additions & 0 deletions tests/autorag/test_deploy.py
@@ -37,6 +37,19 @@ def evaluator():
     yield evaluator


+@pytest.fixture
+def evaluator_data_gen_by_autorag():
+    with tempfile.TemporaryDirectory(ignore_cleanup_errors=True) as project_dir:
+        evaluator = Evaluator(
+            os.path.join(resource_dir, "dataset_sample_gen_by_autorag", "qa.parquet"),
+            os.path.join(
+                resource_dir, "dataset_sample_gen_by_autorag", "corpus.parquet"
+            ),
+            project_dir=project_dir,
+        )
+        yield evaluator


 @pytest.fixture
 def evaluator_trial_done(evaluator):
     evaluator.start_trial(os.path.join(resource_dir, "simple_with_llm.yaml"))
@@ -291,6 +304,45 @@ async def post_to_server_retrieve():
     assert isinstance(passages[0]["score"], float)


+def test_runner_api_server2(evaluator_data_gen_by_autorag):
+    project_dir = evaluator_data_gen_by_autorag.project_dir
+    evaluator_data_gen_by_autorag.start_trial(
+        os.path.join(resource_dir, "simple_mock_with_llm.yaml")
+    )
+    runner = ApiRunner.from_trial_folder(os.path.join(project_dir, "0"))
+
+    client = runner.app.test_client()
+
+    async def post_to_server():
+        # Use the TestClient to make a request to the server
+        response = await client.post(
+            "/v1/run",
+            json={
+                "query": "What is the best movie in Korea? Have Korea movie ever won Oscar?",
+            },
+        )
+        json_response = await response.get_json()
+        return json_response, response.status_code
+
+    nest_asyncio.apply()
+    response_json, response_status_code = asyncio.run(post_to_server())
+    assert response_status_code == 200
+    assert "result" in response_json
+    assert "retrieved_passage" in response_json
+    answer = response_json["result"]
+    assert isinstance(answer, str)
+    assert bool(answer)
+
+    retrieved_contents = response_json["retrieved_passage"]
+    assert len(retrieved_contents) == 10
+    assert isinstance(retrieved_contents[0]["content"], str)
+    assert isinstance(retrieved_contents[0]["doc_id"], str)
+    assert retrieved_contents[0]["filepath"]
+    assert retrieved_contents[0]["file_page"]
+    assert retrieved_contents[0]["start_idx"]
+    assert retrieved_contents[0]["end_idx"]


 @pytest.mark.skip(reason="This test is not working")
 def test_runner_api_server_stream(evaluator_trial_done):
     project_dir = evaluator_trial_done.project_dir
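The new regression test drives the /v1/run endpoint through Quart's in-process test client. Outside of tests, the equivalent flow against a live server would look roughly like the sketch below; the import path, the run_api_server entry point, the port, and the trial path are assumptions based on this diff, not verified against the full codebase:

# Server side (run in its own process; the trial folder "0" is created by
# start_trial, and the import path is assumed from autorag/deploy/api.py):
from autorag.deploy.api import ApiRunner

runner = ApiRunner.from_trial_folder("/path/to/project/0")
runner.run_api_server(host="0.0.0.0", port=8000)  # assumed entry point; blocks

# Client side, from another process:
import requests

resp = requests.post(
    "http://localhost:8000/v1/run",
    json={"query": "What is the best movie in Korea?"},
)
resp.raise_for_status()
body = resp.json()
print(body["result"])
for passage in body["retrieved_passage"]:
    print(passage["doc_id"], passage["start_idx"], passage["end_idx"])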
Binary file not shown.
Binary file not shown.
(These are presumably the qa.parquet and corpus.parquet samples under tests/resources/dataset_sample_gen_by_autorag referenced by the new fixture.)
36 changes: 36 additions & 0 deletions tests/resources/simple_mock_with_llm.yaml
@@ -0,0 +1,36 @@
+vectordb:
+  - name: chroma_default
+    db_type: chroma
+    client_type: persistent
+    embedding_model: mock
+    collection_name: openai
+    path: ${PROJECT_DIR}/resources/chroma
+node_lines:
+  - node_line_name: retrieve_node_line
+    nodes:
+      - node_type: retrieval # represents run_node function
+        strategy: # essential for every node
+          metrics: [retrieval_f1, retrieval_recall]
+        top_k: 10 # node param, which adapt to every module in this node.
+        modules:
+          - module_type: bm25
+            bm25_tokenizer: [ porter_stemmer ]
+          - module_type: vectordb
+            vectordb: chroma_default
+            embedding_batch: 50
+      - node_type: prompt_maker
+        strategy:
+          metrics: [ bleu ]
+          generator_modules:
+            - module_type: llama_index_llm
+              llm: mock
+        modules:
+          - module_type: fstring
+            prompt: "Tell me something about the question: {query} \n\n {retrieved_contents}"
+      - node_type: generator
+        strategy:
+          metrics:
+            - metric_name: bleu
+        modules:
+          - module_type: llama_index_llm
+            llm: mock
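This config mirrors the existing simple_with_llm.yaml but swaps in the mock embedding model and mock LLM, so the trial in test_runner_api_server2 runs without external API calls. Condensed from the fixture and test above, usage is roughly as follows (the project_dir path is a placeholder; the tests use a TemporaryDirectory):

import os

from autorag.evaluator import Evaluator

resource_dir = "tests/resources"
evaluator = Evaluator(
    os.path.join(resource_dir, "dataset_sample_gen_by_autorag", "qa.parquet"),
    os.path.join(resource_dir, "dataset_sample_gen_by_autorag", "corpus.parquet"),
    project_dir="/tmp/autorag_project",  # placeholder
)
evaluator.start_trial(os.path.join(resource_dir, "simple_mock_with_llm.yaml"))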
