Commit

fix empty index + generation synthesizer (#16785)
logan-markewich authored Nov 1, 2024
1 parent 35234d2 commit b28c19a
Showing 2 changed files with 149 additions and 2 deletions.
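
The diff below does two things: the async streaming path now awaits self._llm.astream() instead of calling the synchronous self._llm.stream(), and synthesize/asynthesize are overridden so the Generation synthesizer still produces output when retrieval hands it zero nodes, for example when querying an empty index. A minimal sketch of that empty-index scenario follows; it is not part of the commit, and the query-engine wiring plus the mock components are assumptions made only to keep the example self-contained and offline.

from llama_index.core import Settings, VectorStoreIndex
from llama_index.core.embeddings import MockEmbedding
from llama_index.core.llms import MockLLM

# Offline stand-ins so the sketch runs without API keys (illustrative setup).
Settings.llm = MockLLM()
Settings.embed_model = MockEmbedding(embed_dim=8)

# An index built from no nodes: retrieval should return zero nodes at query time.
index = VectorStoreIndex(nodes=[])

# response_mode="generation" selects the Generation synthesizer patched below;
# with this fix, the empty retrieval result should no longer break synthesis.
query_engine = index.as_query_engine(response_mode="generation")
print(query_engine.query("Hello?"))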
111 changes: 109 additions & 2 deletions llama-index-core/llama_index/core/response_synthesizers/generation.py
@@ -1,15 +1,31 @@
-from typing import Any, Optional, Sequence
+from typing import Any, List, Optional, Sequence

from llama_index.core.base.response.schema import RESPONSE_TYPE
from llama_index.core.callbacks.base import CallbackManager
from llama_index.core.callbacks.schema import CBEventType, EventPayload
from llama_index.core.indices.prompt_helper import PromptHelper
from llama_index.core.instrumentation.events.synthesis import (
    SynthesizeStartEvent,
    SynthesizeEndEvent,
)
import llama_index.core.instrumentation as instrument
from llama_index.core.llms import LLM
from llama_index.core.prompts import BasePromptTemplate
from llama_index.core.prompts.default_prompts import DEFAULT_SIMPLE_INPUT_PROMPT
from llama_index.core.prompts.mixin import PromptDictType
from llama_index.core.response_synthesizers.base import BaseSynthesizer
from llama_index.core.schema import (
    MetadataMode,
    NodeWithScore,
    QueryBundle,
    QueryType,
)
from llama_index.core.types import RESPONSE_TEXT_TYPE


dispatcher = instrument.get_dispatcher(__name__)


class Generation(BaseSynthesizer):
    def __init__(
        self,
@@ -52,7 +68,7 @@ async def aget_response(
                **response_kwargs,
            )
        else:
-            return self._llm.stream(
+            return await self._llm.astream(
                self._input_prompt,
                query_str=query_str,
                **response_kwargs,
@@ -79,3 +95,94 @@ def get_response(
                query_str=query_str,
                **response_kwargs,
            )

    # NOTE: synthesize and asynthesize are copied from the base class,
    # but modified so they still generate a response when zero nodes are provided

    @dispatcher.span
    def synthesize(
        self,
        query: QueryType,
        nodes: List[NodeWithScore],
        additional_source_nodes: Optional[Sequence[NodeWithScore]] = None,
        **response_kwargs: Any,
    ) -> RESPONSE_TYPE:
        dispatcher.event(
            SynthesizeStartEvent(
                query=query,
            )
        )

        if isinstance(query, str):
            query = QueryBundle(query_str=query)

        with self._callback_manager.event(
            CBEventType.SYNTHESIZE,
            payload={EventPayload.QUERY_STR: query.query_str},
        ) as event:
            response_str = self.get_response(
                query_str=query.query_str,
                text_chunks=[
                    n.node.get_content(metadata_mode=MetadataMode.LLM) for n in nodes
                ],
                **response_kwargs,
            )

            additional_source_nodes = additional_source_nodes or []
            source_nodes = list(nodes) + list(additional_source_nodes)

            response = self._prepare_response_output(response_str, source_nodes)

            event.on_end(payload={EventPayload.RESPONSE: response})

        dispatcher.event(
            SynthesizeEndEvent(
                query=query,
                response=response,
            )
        )
        return response

    @dispatcher.span
    async def asynthesize(
        self,
        query: QueryType,
        nodes: List[NodeWithScore],
        additional_source_nodes: Optional[Sequence[NodeWithScore]] = None,
        **response_kwargs: Any,
    ) -> RESPONSE_TYPE:
        dispatcher.event(
            SynthesizeStartEvent(
                query=query,
            )
        )

        if isinstance(query, str):
            query = QueryBundle(query_str=query)

        with self._callback_manager.event(
            CBEventType.SYNTHESIZE,
            payload={EventPayload.QUERY_STR: query.query_str},
        ) as event:
            response_str = await self.aget_response(
                query_str=query.query_str,
                text_chunks=[
                    n.node.get_content(metadata_mode=MetadataMode.LLM) for n in nodes
                ],
                **response_kwargs,
            )

            additional_source_nodes = additional_source_nodes or []
            source_nodes = list(nodes) + list(additional_source_nodes)

            response = self._prepare_response_output(response_str, source_nodes)

            event.on_end(payload={EventPayload.RESPONSE: response})

        dispatcher.event(
            SynthesizeEndEvent(
                query=query,
                response=response,
            )
        )
        return response
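
For contrast, the base class these two methods are copied from short-circuits on zero nodes, and other synthesizers keep that behavior. A small sketch of the difference, assuming the base class still returns its stock "Empty Response" string in that case:

from llama_index.core.llms import MockLLM
from llama_index.core.response_synthesizers import CompactAndRefine
from llama_index.core.response_synthesizers.generation import Generation

llm = MockLLM()

# Generation (patched above): zero nodes still produce an LLM completion.
print(str(Generation(llm=llm).synthesize(query="test", nodes=[])))  # -> "test"

# A synthesizer on the unmodified base-class path: zero nodes short-circuit
# to the stock empty response (assumed behavior of current releases).
print(str(CompactAndRefine(llm=llm).synthesize(query="test", nodes=[])))  # -> "Empty Response"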
40 changes: 40 additions & 0 deletions llama-index-core/tests/response_synthesizers/test_generate.py
@@ -0,0 +1,40 @@
import pytest

from llama_index.core.llms import MockLLM
from llama_index.core.response_synthesizers.generation import Generation


def test_synthesize() -> None:
    synthesizer = Generation(llm=MockLLM())
    response = synthesizer.synthesize(query="test", nodes=[])
    assert str(response) == "test"


def test_synthesize_stream() -> None:
    synthesizer = Generation(llm=MockLLM(), streaming=True)
    response = synthesizer.synthesize(query="test", nodes=[])

    gold = "test"
    i = 0
    for chunk in response.response_gen:
        assert chunk == gold[i]
        i += 1


@pytest.mark.asyncio()
async def test_asynthesize() -> None:
    synthesizer = Generation(llm=MockLLM())
    response = await synthesizer.asynthesize(query="test", nodes=[])
    assert str(response) == "test"


@pytest.mark.asyncio()
async def test_asynthesize_stream() -> None:
    synthesizer = Generation(llm=MockLLM(), streaming=True)
    response = await synthesizer.asynthesize(query="test", nodes=[])

    gold = "test"
    i = 0
    async for chunk in response.async_response_gen():
        assert chunk == gold[i]
        i += 1
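
A short sketch of consuming the async streaming path that the last test exercises, in application code rather than a test; MockLLM simply echoes the prompt, so nothing external is assumed:

import asyncio

from llama_index.core.llms import MockLLM
from llama_index.core.response_synthesizers.generation import Generation


async def main() -> None:
    # streaming=True plus zero nodes: aget_response now awaits self._llm.astream()
    # instead of calling the synchronous stream() helper.
    synthesizer = Generation(llm=MockLLM(), streaming=True)
    response = await synthesizer.asynthesize(query="test", nodes=[])
    async for token in response.async_response_gen():
        print(token, end="")


asyncio.run(main())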
