Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement support to BatchAPIs to gather evidence #687

Open
wants to merge 32 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
32 commits
Select commit Hold shift + click to select a range
2385dce
Implements OpenAIBatchLLMModel
maykcaldas Nov 14, 2024
8a21055
Incorporates OpenAIBatchLLMModel to get_evidence
maykcaldas Nov 14, 2024
5f59681
Merge branch 'main' into batch_api
maykcaldas Nov 15, 2024
6f7bbb5
Merge branch 'main' into batch_api
maykcaldas Nov 16, 2024
e8dc0d0
Started anthropic batch api support implementation
maykcaldas Nov 15, 2024
899de43
Removed the skip_system argument from the new classes and tests to ma…
maykcaldas Nov 15, 2024
16c3988
Switched to async OpenAI client
maykcaldas Nov 16, 2024
d10a268
Added logging to the batch processing
maykcaldas Nov 16, 2024
0fe9aa1
Created mock server to test openAI batch API
maykcaldas Nov 18, 2024
a9ad540
Implemented batch support to Anthropic
maykcaldas Nov 18, 2024
9a0a6c4
Merge branch 'main' into batch_api
maykcaldas Nov 18, 2024
723650d
Updated uv.lock to include imports for the batch API
maykcaldas Nov 18, 2024
660bfa0
Implements tests with a mocked server for anthropic
maykcaldas Nov 18, 2024
977a025
Added type hints to satisfy the pre-commit
maykcaldas Nov 19, 2024
ee351f2
Merge branch 'main' into batch_api
maykcaldas Nov 19, 2024
293658a
Updates uv on github actions to include extra requirements
maykcaldas Nov 19, 2024
1ad1c7c
Removed the --all-extras flag from uv in github workflow
maykcaldas Nov 19, 2024
af32005
Refactored OpenAiBatchStatus and AnthropicBatchStatus to make the cod…
maykcaldas Nov 19, 2024
63e4b39
[pre-commit.ci lite] apply automatic fixes
pre-commit-ci-lite[bot] Nov 19, 2024
f61e629
Merge branch 'main' into batch_api
maykcaldas Nov 19, 2024
d7dbd72
Cleaned unneeded comments
maykcaldas Nov 19, 2024
7c37f6d
Updated the way the system message is passed to anthropic
maykcaldas Nov 19, 2024
de18907
changed how the file is passed to openai
maykcaldas Nov 20, 2024
3e72bd4
[pre-commit.ci lite] apply automatic fixes
pre-commit-ci-lite[bot] Nov 20, 2024
7c7f4b8
Avoided writing to a file when sending the batch to openAi
maykcaldas Nov 20, 2024
6c8f186
Skipped writing a file. Instead, the content is directly passed to th…
maykcaldas Nov 20, 2024
0e43a7c
Merge branch 'main' into batch_api
maykcaldas Nov 20, 2024
17c26eb
Fixed lint error
maykcaldas Nov 20, 2024
c258306
Updated the batch time limit settings name
maykcaldas Nov 20, 2024
4b8e1c3
Removed type hints from docstrings in gather_with_batch
maykcaldas Nov 20, 2024
8b5c1fa
Added exception in map_fxn_summary to treat multiple reponses
maykcaldas Nov 20, 2024
ab40b54
Added a description explaining the llm_type attribute
maykcaldas Nov 20, 2024
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions paperqa/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,13 +11,15 @@
from paperqa.agents.models import QueryRequest # noqa: E402
from paperqa.docs import Docs, PQASession, print_callback # noqa: E402
from paperqa.llms import ( # noqa: E402
AnthropicBatchLLMModel,
EmbeddingModel,
HybridEmbeddingModel,
LiteLLMEmbeddingModel,
LiteLLMModel,
LLMModel,
LLMResult,
NumpyVectorStore,
OpenAIBatchLLMModel,
SentenceTransformerEmbeddingModel,
SparseEmbeddingModel,
embedding_model_factory,
Expand All @@ -28,6 +30,7 @@

__all__ = [
"Answer",
"AnthropicBatchLLMModel",
"Context",
"Doc",
"DocDetails",
Expand All @@ -39,6 +42,7 @@
"LiteLLMEmbeddingModel",
"LiteLLMModel",
"NumpyVectorStore",
"OpenAIBatchLLMModel",
"PQASession",
"QueryRequest",
"SentenceTransformerEmbeddingModel",
Expand Down
8 changes: 6 additions & 2 deletions paperqa/agents/env.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,11 @@
)

from paperqa.docs import Docs
from paperqa.llms import EmbeddingModel, LiteLLMModel
from paperqa.llms import (
EmbeddingModel,
LiteLLMModel,
LLMBatchModel,
)
from paperqa.settings import Settings
from paperqa.types import PQASession
from paperqa.utils import get_year
Expand All @@ -37,7 +41,7 @@
def settings_to_tools(
settings: Settings,
llm_model: LiteLLMModel | None = POPULATE_FROM_SETTINGS,
summary_llm_model: LiteLLMModel | None = POPULATE_FROM_SETTINGS,
summary_llm_model: LiteLLMModel | LLMBatchModel | None = POPULATE_FROM_SETTINGS,
embedding_model: EmbeddingModel | None = POPULATE_FROM_SETTINGS,
) -> list[Tool]:
"""
Expand Down
83 changes: 81 additions & 2 deletions paperqa/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import json
import re
from collections.abc import Callable
from collections.abc import Callable, Sequence
from typing import Any

from paperqa.llms import PromptRunner
Expand Down Expand Up @@ -68,12 +68,19 @@ async def map_fxn_summary(
success = False

if prompt_runner:
llm_result = await prompt_runner(
result = await prompt_runner(
{"question": question, "citation": citation, "text": text.text}
| (extra_prompt_data or {}),
callbacks,
"evidence:" + text.name,
)

if isinstance(result, Sequence) and len(result) != 1:
raise NotImplementedError(
f"Expected a single LLMResult, got {len(result)}. : {result}"
)

llm_result = result if isinstance(result, LLMResult) else result[0]
maykcaldas marked this conversation as resolved.
Show resolved Hide resolved
context = llm_result.text
result_data = parser(context) if parser else {}
success = bool(result_data)
Expand Down Expand Up @@ -115,3 +122,75 @@ async def map_fxn_summary(
),
llm_result,
)


async def gather_with_batch(
jamesbraza marked this conversation as resolved.
Show resolved Hide resolved
matches: list[Text],
question: str,
prompt_runner: PromptRunner | None,
extra_prompt_data: dict[str, str] | None = None,
parser: Callable[[str], dict[str, Any]] | None = None,
callbacks: list[Callable[[str], None]] | None = None,
) -> list[tuple[Context, LLMResult]]:
"""
Gathers evidence considering a batch of texts. The completions are obtained using a batch API.

Args:
matches: A list of text matches to gather evidence from.
question: The question to be answered.
prompt_runner: The prompt runner to use for obtaining completions.
extra_prompt_data: Additional data to include in the prompt.
parser: A function to parse the LLM result text.
callbacks: A list of callback functions to be called
with the LLM result text.

Returns:
List of tuples containing the context and LLM result for each match.
"""
data = [
{
"question": question,
"citation": m.name + ": " + m.doc.formatted_citation,
"text": m.text,
}
| (extra_prompt_data or {})
for m in matches
]

llm_results: list[LLMResult] = []
if prompt_runner:
result = await prompt_runner(
data,
callbacks,
"evidence:" + matches[0].name,
)

llm_results = result if isinstance(result, list) else [result]

results_data = []
scores = []
for r in llm_results:
if parser:
res = parser(r.text)
results_data.append(res)
scores.append(res.pop("relevance_score"))
# just in case question was present
res.pop("question", None)
else:
results_data.append({})
scores.append(extract_score(r.text))

return [
(
Context(
context=strip_citations(llm_result.text),
text=m,
score=score,
**r,
),
llm_result,
)
jamesbraza marked this conversation as resolved.
Show resolved Hide resolved
for r, m, llm_result, score in zip(
results_data, matches, llm_results, scores, strict=True
)
]
61 changes: 38 additions & 23 deletions paperqa/docs.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,10 @@
)

from paperqa.clients import DEFAULT_CLIENTS, DocMetadataClient
from paperqa.core import llm_parse_json, map_fxn_summary
from paperqa.core import gather_with_batch, llm_parse_json, map_fxn_summary
from paperqa.llms import (
EmbeddingModel,
LLMBatchModel,
LLMModel,
NumpyVectorStore,
PromptRunner,
Expand Down Expand Up @@ -559,14 +560,14 @@ def get_evidence(
)
)

async def aget_evidence(
async def aget_evidence( # noqa: PLR0912
self,
query: PQASession | str,
exclude_text_filter: set[str] | None = None,
settings: MaybeSettings = None,
callbacks: list[Callable] | None = None,
embedding_model: EmbeddingModel | None = None,
summary_llm_model: LLMModel | None = None,
summary_llm_model: LLMModel | LLMBatchModel | None = None,
) -> PQASession:

evidence_settings = get_settings(settings)
Expand Down Expand Up @@ -629,28 +630,42 @@ async def aget_evidence(
)

with set_llm_session_ids(session.id):
results = await gather_with_concurrency(
answer_config.max_concurrent_requests,
[
map_fxn_summary(
text=m,
question=session.question,
prompt_runner=prompt_runner,
extra_prompt_data={
"summary_length": answer_config.evidence_summary_length,
"citation": f"{m.name}: {m.doc.formatted_citation}",
},
parser=llm_parse_json if prompt_config.use_json else None,
callbacks=callbacks,
)
for m in matches
],
)
if evidence_settings.use_batch_in_summary:
results = await gather_with_batch(
matches=matches,
question=session.question,
prompt_runner=prompt_runner,
extra_prompt_data={
"summary_length": answer_config.evidence_summary_length,
# citations are formatted inside the function
# for each text in matches
},
parser=llm_parse_json if prompt_config.use_json else None,
callbacks=callbacks,
)
else:
results = await gather_with_concurrency(
answer_config.max_concurrent_requests,
[
map_fxn_summary(
text=m,
question=session.question,
prompt_runner=prompt_runner,
extra_prompt_data={
"summary_length": answer_config.evidence_summary_length,
"citation": f"{m.name}: {m.doc.formatted_citation}",
},
parser=llm_parse_json if prompt_config.use_json else None,
callbacks=callbacks,
)
for m in matches
],
)

for _, llm_result in results:
session.add_tokens(llm_result)

session.contexts += [r for r, _ in results if r is not None]
session.contexts += [r for r, _ in results]
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why did we cut the r is not None filter here? I would think that the results from gather_with_concurrency could still be None on failure, but maybe I'm wrong

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This gets the Contexts from gather_with_concurrency or gather_with_batch. And both always return list of tuples with (Context, LLMResult). What can happen is to have an empty text in Context.text, but it seems to me that r is always an instance of Context.
Also, I didn't see any case of map_fxn_summary returning None while studying the code, and mypy also complains that r is None is always a True statement.

Maybe that's an edge case that I didn't see?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If we correctly type hinted gather_with_concurrency then this would be resolved. @maykcaldas can you adjust it to be this?

T = TypeVar("T")


async def gather_with_concurrency(n: int, coros: Iterable[Awaitable[T]]) -> list[T]:
    ...
```

return session

def query(
Expand All @@ -659,7 +674,7 @@ def query(
settings: MaybeSettings = None,
callbacks: list[Callable] | None = None,
llm_model: LLMModel | None = None,
summary_llm_model: LLMModel | None = None,
summary_llm_model: LLMModel | LLMBatchModel | None = None,
embedding_model: EmbeddingModel | None = None,
) -> PQASession:
return get_loop().run_until_complete(
Expand All @@ -679,7 +694,7 @@ async def aquery( # noqa: PLR0912
settings: MaybeSettings = None,
callbacks: list[Callable] | None = None,
llm_model: LLMModel | None = None,
summary_llm_model: LLMModel | None = None,
summary_llm_model: LLMModel | LLMBatchModel | None = None,
embedding_model: EmbeddingModel | None = None,
) -> PQASession:

Expand Down
Loading