Skip to content

Commit

Permalink
patch: add selected result to response metadata for router query engines, fix bug (#8483)
Browse files Browse the repository at this point in the history
  • Loading branch information
jerryjliu authored Oct 26, 2023
1 parent 9f37cfa commit e06250a
Show file tree
Hide file tree
Showing 3 changed files with 49 additions and 34 deletions.
59 changes: 25 additions & 34 deletions docs/examples/query_engine/RouterQueryEngine.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -34,16 +34,7 @@
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Note: NumExpr detected 12 cores but \"NUMEXPR_MAX_THREADS\" not set, so enforcing safe limit of 8.\n",
"NumExpr defaulting to 8 threads.\n"
]
}
],
"outputs": [],
"source": [
"import logging\n",
"import sys\n",
Expand Down Expand Up @@ -113,18 +104,7 @@
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"> [build_index_from_nodes] Total LLM token usage: 0 tokens\n",
"> [build_index_from_nodes] Total embedding token usage: 0 tokens\n",
"> [build_index_from_nodes] Total LLM token usage: 0 tokens\n",
"> [build_index_from_nodes] Total embedding token usage: 17038 tokens\n"
]
}
],
"outputs": [],
"source": [
"summary_index = SummaryIndex(nodes, storage_context=storage_context)\n",
"vector_index = VectorStoreIndex(nodes, storage_context=storage_context)"
Expand Down Expand Up @@ -285,6 +265,16 @@
"print(str(response))"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# [optional] look at selected results\n",
"print(str(response.metadata[\"selector_result\"]))"
]
},
{
"cell_type": "markdown",
"metadata": {},
Expand All @@ -298,16 +288,7 @@
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"> [build_index_from_nodes] Total LLM token usage: 0 tokens\n",
"> [build_index_from_nodes] Total embedding token usage: 0 tokens\n"
]
}
],
"outputs": [],
"source": [
"from llama_index import SimpleKeywordTableIndex\n",
"\n",
Expand Down Expand Up @@ -351,13 +332,23 @@
")\n",
"print(str(response))"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# [optional] look at selected results\n",
"print(str(response.metadata[\"selector_result\"]))"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"display_name": "llama_index_v2",
"language": "python",
"name": "python3"
"name": "llama_index_v2"
},
"language_info": {
"codemirror_mode": {
Expand Down
16 changes: 16 additions & 0 deletions llama_index/query_engine/router_query_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,6 +177,10 @@ def _query(self, query_bundle: QueryBundle) -> RESPONSE_TYPE:

final_response = selected_query_engine.query(query_bundle)

# add selected result
final_response.metadata = final_response.metadata or {}
final_response.metadata["selector_result"] = result

query_event.on_end(payload={EventPayload.RESPONSE: final_response})

return final_response
Expand Down Expand Up @@ -214,6 +218,10 @@ async def _aquery(self, query_bundle: QueryBundle) -> RESPONSE_TYPE:

final_response = await selected_query_engine.aquery(query_bundle)

# add selected result
final_response.metadata = final_response.metadata or {}
final_response.metadata["selector_result"] = result

query_event.on_end(payload={EventPayload.RESPONSE: final_response})

return final_response
Expand Down Expand Up @@ -332,6 +340,10 @@ def _query(self, query_bundle: QueryBundle) -> RESPONSE_TYPE:
else:
final_response = responses[0]

# add selected result
final_response.metadata = final_response.metadata or {}
final_response.metadata["retrieved_tools"] = query_engine_tools

query_event.on_end(payload={EventPayload.RESPONSE: final_response})

return final_response
Expand All @@ -353,6 +365,10 @@ async def _aquery(self, query_bundle: QueryBundle) -> RESPONSE_TYPE:
else:
final_response = responses[0]

# add selected result
final_response.metadata = final_response.metadata or {}
final_response.metadata["retrieved_tools"] = query_engine_tools

query_event.on_end(payload={EventPayload.RESPONSE: final_response})

return final_response
8 changes: 8 additions & 0 deletions llama_index/selectors/pydantic_selectors.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,14 @@ def from_defaults(

return cls(selector_program=program, max_outputs=max_outputs)

def _get_prompts(self) -> Dict[str, Any]:
"""Get prompts."""
# TODO: no accessible prompts for a base pydantic program
return {}

def _update_prompts(self, prompts: PromptDictType) -> None:
    """Update prompts.

    No-op: the wrapped pydantic program exposes no prompts to update.
    """

def _select(
self, choices: Sequence[ToolMetadata], query: QueryBundle
) -> SelectorResult:
Expand Down

0 comments on commit e06250a

Please sign in to comment.