diff --git a/paperqa/__init__.py b/paperqa/__init__.py index 008b1825..ab8ea15b 100644 --- a/paperqa/__init__.py +++ b/paperqa/__init__.py @@ -11,6 +11,7 @@ from paperqa.agents.models import QueryRequest # noqa: E402 from paperqa.docs import Docs, PQASession, print_callback # noqa: E402 from paperqa.llms import ( # noqa: E402 + AnthropicBatchLLMModel, EmbeddingModel, HybridEmbeddingModel, LiteLLMEmbeddingModel, @@ -18,6 +19,7 @@ LLMModel, LLMResult, NumpyVectorStore, + OpenAIBatchLLMModel, SentenceTransformerEmbeddingModel, SparseEmbeddingModel, embedding_model_factory, @@ -28,6 +30,7 @@ __all__ = [ "Answer", + "AnthropicBatchLLMModel", "Context", "Doc", "DocDetails", @@ -39,6 +42,7 @@ "LiteLLMEmbeddingModel", "LiteLLMModel", "NumpyVectorStore", + "OpenAIBatchLLMModel", "PQASession", "QueryRequest", "SentenceTransformerEmbeddingModel", diff --git a/paperqa/agents/env.py b/paperqa/agents/env.py index 8d846c02..535f0ada 100644 --- a/paperqa/agents/env.py +++ b/paperqa/agents/env.py @@ -13,7 +13,11 @@ ) from paperqa.docs import Docs -from paperqa.llms import EmbeddingModel, LiteLLMModel +from paperqa.llms import ( + EmbeddingModel, + LiteLLMModel, + LLMBatchModel, +) from paperqa.settings import Settings from paperqa.types import PQASession from paperqa.utils import get_year @@ -37,7 +41,7 @@ def settings_to_tools( settings: Settings, llm_model: LiteLLMModel | None = POPULATE_FROM_SETTINGS, - summary_llm_model: LiteLLMModel | None = POPULATE_FROM_SETTINGS, + summary_llm_model: LiteLLMModel | LLMBatchModel | None = POPULATE_FROM_SETTINGS, embedding_model: EmbeddingModel | None = POPULATE_FROM_SETTINGS, ) -> list[Tool]: """ diff --git a/paperqa/core.py b/paperqa/core.py index 5ceb0060..a545da23 100644 --- a/paperqa/core.py +++ b/paperqa/core.py @@ -2,7 +2,7 @@ import json import re -from collections.abc import Callable +from collections.abc import Callable, Sequence from typing import Any from paperqa.llms import PromptRunner @@ -68,12 +68,19 @@ async def map_fxn_summary( success = False if prompt_runner: - llm_result = await prompt_runner( + result = await prompt_runner( {"question": question, "citation": citation, "text": text.text} | (extra_prompt_data or {}), callbacks, "evidence:" + text.name, ) + + if isinstance(result, Sequence) and len(result) != 1: + raise NotImplementedError( + f"Expected a single LLMResult, got {len(result)}. : {result}" + ) + + llm_result = result if isinstance(result, LLMResult) else result[0] context = llm_result.text result_data = parser(context) if parser else {} success = bool(result_data) @@ -115,3 +122,75 @@ async def map_fxn_summary( ), llm_result, ) + + +async def gather_with_batch( + matches: list[Text], + question: str, + prompt_runner: PromptRunner | None, + extra_prompt_data: dict[str, str] | None = None, + parser: Callable[[str], dict[str, Any]] | None = None, + callbacks: list[Callable[[str], None]] | None = None, +) -> list[tuple[Context, LLMResult]]: + """ + Gathers evidence considering a batch of texts. The completions are obtained using a batch API. + + Args: + matches: A list of text matches to gather evidence from. + question: The question to be answered. + prompt_runner: The prompt runner to use for obtaining completions. + extra_prompt_data: Additional data to include in the prompt. + parser: A function to parse the LLM result text. + callbacks: A list of callback functions to be called + with the LLM result text. + + Returns: + List of tuples containing the context and LLM result for each match. 
+ """ + data = [ + { + "question": question, + "citation": m.name + ": " + m.doc.formatted_citation, + "text": m.text, + } + | (extra_prompt_data or {}) + for m in matches + ] + + llm_results: list[LLMResult] = [] + if prompt_runner: + result = await prompt_runner( + data, + callbacks, + "evidence:" + matches[0].name, + ) + + llm_results = result if isinstance(result, list) else [result] + + results_data = [] + scores = [] + for r in llm_results: + if parser: + res = parser(r.text) + results_data.append(res) + scores.append(res.pop("relevance_score")) + # just in case question was present + res.pop("question", None) + else: + results_data.append({}) + scores.append(extract_score(r.text)) + + return [ + ( + Context( + context=strip_citations(llm_result.text), + text=m, + score=score, + **r, + ), + llm_result, + ) + for r, m, llm_result, score in zip( + results_data, matches, llm_results, scores, strict=True + ) + ] diff --git a/paperqa/docs.py b/paperqa/docs.py index 949cd922..1d2a4bf7 100644 --- a/paperqa/docs.py +++ b/paperqa/docs.py @@ -22,9 +22,10 @@ ) from paperqa.clients import DEFAULT_CLIENTS, DocMetadataClient -from paperqa.core import llm_parse_json, map_fxn_summary +from paperqa.core import gather_with_batch, llm_parse_json, map_fxn_summary from paperqa.llms import ( EmbeddingModel, + LLMBatchModel, LLMModel, NumpyVectorStore, PromptRunner, @@ -559,14 +560,14 @@ def get_evidence( ) ) - async def aget_evidence( + async def aget_evidence( # noqa: PLR0912 self, query: PQASession | str, exclude_text_filter: set[str] | None = None, settings: MaybeSettings = None, callbacks: list[Callable] | None = None, embedding_model: EmbeddingModel | None = None, - summary_llm_model: LLMModel | None = None, + summary_llm_model: LLMModel | LLMBatchModel | None = None, ) -> PQASession: evidence_settings = get_settings(settings) @@ -629,28 +630,42 @@ async def aget_evidence( ) with set_llm_session_ids(session.id): - results = await gather_with_concurrency( - answer_config.max_concurrent_requests, - [ - map_fxn_summary( - text=m, - question=session.question, - prompt_runner=prompt_runner, - extra_prompt_data={ - "summary_length": answer_config.evidence_summary_length, - "citation": f"{m.name}: {m.doc.formatted_citation}", - }, - parser=llm_parse_json if prompt_config.use_json else None, - callbacks=callbacks, - ) - for m in matches - ], - ) + if evidence_settings.use_batch_in_summary: + results = await gather_with_batch( + matches=matches, + question=session.question, + prompt_runner=prompt_runner, + extra_prompt_data={ + "summary_length": answer_config.evidence_summary_length, + # citations are formatted inside the function + # for each text in matches + }, + parser=llm_parse_json if prompt_config.use_json else None, + callbacks=callbacks, + ) + else: + results = await gather_with_concurrency( + answer_config.max_concurrent_requests, + [ + map_fxn_summary( + text=m, + question=session.question, + prompt_runner=prompt_runner, + extra_prompt_data={ + "summary_length": answer_config.evidence_summary_length, + "citation": f"{m.name}: {m.doc.formatted_citation}", + }, + parser=llm_parse_json if prompt_config.use_json else None, + callbacks=callbacks, + ) + for m in matches + ], + ) for _, llm_result in results: session.add_tokens(llm_result) - session.contexts += [r for r, _ in results if r is not None] + session.contexts += [r for r, _ in results] return session def query( @@ -659,7 +674,7 @@ def query( settings: MaybeSettings = None, callbacks: list[Callable] | None = None, llm_model: LLMModel | None = 
None, - summary_llm_model: LLMModel | None = None, + summary_llm_model: LLMModel | LLMBatchModel | None = None, embedding_model: EmbeddingModel | None = None, ) -> PQASession: return get_loop().run_until_complete( @@ -679,7 +694,7 @@ async def aquery( # noqa: PLR0912 settings: MaybeSettings = None, callbacks: list[Callable] | None = None, llm_model: LLMModel | None = None, - summary_llm_model: LLMModel | None = None, + summary_llm_model: LLMModel | LLMBatchModel | None = None, embedding_model: EmbeddingModel | None = None, ) -> PQASession: diff --git a/paperqa/llms.py b/paperqa/llms.py index b8b62805..f1a42d44 100644 --- a/paperqa/llms.py +++ b/paperqa/llms.py @@ -3,6 +3,9 @@ import asyncio import contextlib import functools +import io +import json +import logging from abc import ABC, abstractmethod from collections.abc import ( AsyncGenerator, @@ -16,7 +19,7 @@ from enum import StrEnum from inspect import isasyncgenfunction, signature from sys import version_info -from typing import Any, TypeVar, cast +from typing import Any, TypedDict, TypeVar, cast import litellm import numpy as np @@ -36,9 +39,11 @@ from paperqa.types import Embeddable, LLMResult from paperqa.utils import is_coroutine_callable +logger = logging.getLogger(__name__) + PromptRunner = Callable[ - [dict, list[Callable[[str], None]] | None, str | None], - Awaitable[LLMResult], + [dict | list[dict], list[Callable[[str], None]] | None, str | None], + Awaitable[LLMResult | list[LLMResult]], ] MODEL_COST_MAP = litellm.get_model_cost_map("") @@ -70,6 +75,39 @@ class EmbeddingModes(StrEnum): QUERY = "query" +class BatchStatus(StrEnum): + COMPLETE = "complete" + PROGRESS = "progress" + SUCCESS = "success" + FAILURE = "failure" + EXPIRE = "expire" + CANCEL = "cancel" + + def from_openai(self) -> str: + """Convert BatchStatus to OpenAI status.""" + mapping = { + BatchStatus.COMPLETE: "completed", + BatchStatus.PROGRESS: "in_progress", + BatchStatus.SUCCESS: "completed", + BatchStatus.FAILURE: "failed", + BatchStatus.EXPIRE: "expired", + BatchStatus.CANCEL: "cancelled", + } + return mapping[self] + + def from_anthropic(self) -> str: + """Convert BatchStatus to Anthropic status.""" + mapping = { + BatchStatus.COMPLETE: "ended", + BatchStatus.PROGRESS: "in_progress", + BatchStatus.SUCCESS: "succeeded", + BatchStatus.FAILURE: "errored", + BatchStatus.EXPIRE: "expired", + BatchStatus.CANCEL: "canceled", + } + return mapping[self] + + # Estimate from OpenAI's FAQ # https://help.openai.com/en/articles/4936856-what-are-tokens-and-how-to-count-them CHARACTERS_PER_TOKEN_ASSUMPTION: float = 4.0 @@ -275,7 +313,10 @@ def __str__(self): class LLMModel(ABC, BaseModel): model_config = ConfigDict(extra="forbid", arbitrary_types_allowed=True) - llm_type: str | None = None + llm_type: str | StrEnum | None = Field( + default=None, + description="A string indicating the type of LLM model (e.g., 'chat' or 'completion').", + ) name: str llm_result_callback: ( Callable[[LLMResult], None] | Callable[[LLMResult], Awaitable[None]] | None @@ -764,6 +805,380 @@ async def select_tool( return await tool_selector(*selection_args, **selection_kwargs) +class LLMBatchModel(ABC, BaseModel): + model_config = ConfigDict(extra="forbid", arbitrary_types_allowed=True) + + llm_type: str | StrEnum | None = Field( + default=None, + description="A string indicating the type of LLM model (e.g., 'chat' or 'completion').", + ) + name: str + llm_result_callback: ( + Callable[[LLMResult], None] | Callable[[LLMResult], Awaitable[None]] | None + ) = Field( + default=None, + 
description=( + "An async callback that will be executed on each" + " LLMResult (different than callbacks that execute on each chunk)" + ), + exclude=True, + ) + config: dict = Field(default_factory=dict) + + async def run_prompt( + self, + prompt: str, + data: list[dict], + callbacks: list[Callable] | None = None, + name: str | None = None, + system_prompt: str | None = default_system_prompt, + ) -> list[LLMResult]: + if self.llm_type is None: + self.llm_type = self.infer_llm_type() + if self.llm_type == "chat": + return await self._run_chat(prompt, data, callbacks, name, system_prompt) + if self.llm_type == "completion": + return await self._run_completion( + prompt, data, callbacks, name, system_prompt + ) + raise ValueError(f"Unknown llm_type {self.llm_type!r}.") + + async def _run_chat( + self, + prompt: str, + data: list[dict], + callbacks: list[Callable] | None = None, + name: str | None = None, + system_prompt: str | None = default_system_prompt, + ) -> list[LLMResult]: + raise NotImplementedError + + async def _run_completion( + self, + prompt: str, + data: list[dict], + callbacks: list[Callable] | None = None, + name: str | None = None, + system_prompt: str | None = default_system_prompt, + ) -> list[LLMResult]: + raise NotImplementedError + + def infer_llm_type(self) -> str: + return "chat" + + def count_tokens(self, text: str) -> int: + return len(text) // 4 + + +class Body(TypedDict): + model: str | None + messages: list[dict[str, str]] | None + max_tokens: int | None + + +class BatchTemplate(TypedDict): + custom_id: str | None + method: str + url: str + body: Body + + +class OpenAIBatchLLMModel(LLMBatchModel): + """A wrapper around the OpenAI library to use the batch API.""" + + name: str = "gpt-4o-mini" + config: dict = Field( + default_factory=dict, + description="Configuration dictionary for this model. Currently supported keys are `model` and `max_tokens`.", + ) + + async def write_jsonl( + self, + data: list[list[dict[str, str]]], + mem_buffer: io.BytesIO, + ): + batch_template: BatchTemplate = { + "custom_id": None, + "method": "POST", + "url": "/v1/chat/completions", + "body": {"model": None, "messages": None, "max_tokens": None}, + } + + for i, d in enumerate(data): + batch_template["custom_id"] = str(i) + batch_template["body"]["model"] = self.config.get("model") + batch_template["body"]["messages"] = d + batch_template["body"]["max_tokens"] = self.config.get("max_tokens") + serialized_data = json.dumps(batch_template) + "\n" + mem_buffer.write(serialized_data.encode()) + + async def acomplete(self): + raise NotImplementedError("Only chat models are supported by the OpenAI batch API.") + + async def acomplete_iter(self): + raise NotImplementedError( + "Async generators are not supported for batch calls, and only chat models are supported by the OpenAI batch API." 
+ ) + + async def _run_chat( + self, + prompt: str, + data: list[dict], + callbacks: list[Callable] | None = None, + name: str | None = None, + system_prompt: str | None = default_system_prompt, + ) -> list[LLMResult]: + if callbacks: + sync_callbacks = [f for f in callbacks if not is_coroutine_callable(f)] + async_callbacks = [f for f in callbacks if is_coroutine_callable(f)] + + human_message_prompt = {"role": "user", "content": prompt} + + batch = [] + for d in data: + messages = [ + {"role": m["role"], "content": m["content"].format(**d)} + for m in ( + [{"role": "system", "content": system_prompt}, human_message_prompt] + if system_prompt + else [human_message_prompt] + ) + ] + batch.append(messages) + + start_clock = asyncio.get_running_loop().time() + chunks = await self.achat(batch) + batch_time = asyncio.get_running_loop().time() - start_clock + + if callbacks: + for chunk in chunks: + await do_callbacks( + async_callbacks, sync_callbacks, chunk.text or "", name + ) + + return [ + LLMResult( + model=self.name, + name=name, + prompt=messages, + prompt_count=chunk.prompt_tokens, + text=chunk.text, + completion_count=chunk.completion_tokens, + seconds_to_first_token=batch_time, + seconds_to_last_token=batch_time, + ) + for messages, chunk in zip(batch, chunks, strict=True) + ] + + async def achat(self, messages: list[list[dict]]) -> list[Chunk]: + try: + import openai + except ImportError as exc: + raise ImportError( + "Please install paper-qa[batch] to use OpenAIBatchLLMModel." + ) from exc + + client = openai.AsyncOpenAI() + + with io.BytesIO() as mem_buffer: + await self.write_jsonl(messages, mem_buffer) + mem_buffer.seek(0) + file = await client.files.create(file=mem_buffer, purpose="batch") + + batch = await client.batches.create( + input_file_id=file.id, + endpoint="/v1/chat/completions", + completion_window="24h", + metadata={"description": ""}, + ) + + start_clock = asyncio.get_running_loop().time() + while batch.status != BatchStatus.COMPLETE.from_openai(): + batch = await client.batches.retrieve(batch.id) + if batch.status == BatchStatus.FAILURE.from_openai(): + error_messages = [] + if batch.errors and hasattr(batch.errors, "data") and batch.errors.data: + error_messages = [ + str(k.message) + for k in batch.errors.data + if k.message is not None + ] + raise RuntimeError( + "Batch failed. 
\n\nReason: \n" + "\n".join(error_messages) + ) + if batch.status == BatchStatus.CANCEL.from_openai(): + raise ConnectionError("Batch was cancelled.") + + batch_time = asyncio.get_running_loop().time() - start_clock + if batch_time > self.config.get("batch_summary_time_limit", 24 * 60 * 60): + raise TimeoutError("Batch took too long to complete.") + + logger.info( + f"Summary batch status: {batch.status} | Time elapsed: {batch_time}" + ) + await asyncio.sleep(self.config.get("batch_polling_interval", 30)) + + if batch.output_file_id: + api_responses = await client.files.content(batch.output_file_id) + else: + raise RuntimeError("Batch failed to generate output file.") + sorted_responses = sorted( + [ + json.loads(line) + for line in api_responses.read().decode("utf-8").splitlines() + ], + key=lambda x: int(x["custom_id"]), + ) # The batchAPI doesn't guarantee the order of the responses + + return [ + Chunk( + text=response["response"]["body"]["choices"][0]["message"]["content"], + prompt_tokens=response["response"]["body"]["usage"]["prompt_tokens"], + completion_tokens=response["response"]["body"]["usage"][ + "completion_tokens" + ], + ) + for response in sorted_responses + ] + + async def achat_iter(self): + raise NotImplementedError( + "Async generator not supported for batch calls. Use achat instead." + ) + + +class AnthropicBatchLLMModel(LLMBatchModel): + """A wrapper around the anthropic library to use the batch API.""" + + name: str = "claude-3-5-sonnet-latest" + config: dict = Field( + default_factory=dict, + description="Configuration dictionary for this model. Currently supported keys are `model` and `max_token`.", + ) + + async def acomplete(self): + raise NotImplementedError("Completion models are not supported yet") + + async def acomplete_iter(self): + raise NotImplementedError("Completion models are not supported yet") + + async def _run_chat( + self, + prompt: str, + data: list[dict], + callbacks: list[Callable] | None = None, + name: str | None = None, + system_prompt: str | None = default_system_prompt, + ) -> list[LLMResult]: + if callbacks: + sync_callbacks = [f for f in callbacks if not is_coroutine_callable(f)] + async_callbacks = [f for f in callbacks if is_coroutine_callable(f)] + + human_message_prompt = {"role": "user", "content": prompt} + + batch = [] + for d in data: + messages = [ + {"role": m["role"], "content": m["content"].format(**d)} + for m in ( + [{"role": "system", "content": system_prompt}, human_message_prompt] + if system_prompt + else [human_message_prompt] + ) + ] + batch.append(messages) + + start_clock = asyncio.get_running_loop().time() + chunks = await self.achat(batch) + batch_time = asyncio.get_running_loop().time() - start_clock + + if callbacks: + for chunk in chunks: + await do_callbacks( + async_callbacks, sync_callbacks, chunk.text or "", name + ) + + return [ + LLMResult( + model=self.name, + name=name, + prompt=messages, + prompt_count=chunk.prompt_tokens, + text=chunk.text, + completion_count=chunk.completion_tokens, + seconds_to_first_token=batch_time, + seconds_to_last_token=batch_time, + ) + for messages, chunk in zip(batch, chunks, strict=True) + ] + + async def achat(self, messages: list[list[dict[str, str]]]) -> list[Chunk]: + try: + import anthropic + from anthropic.types.beta.message_create_params import ( + MessageCreateParamsNonStreaming, + ) + from anthropic.types.beta.messages.batch_create_params import Request + except ImportError as exc: + raise ImportError( + "Please install paper-qa[batch] to use 
AnthropicBatchLLMModel." + ) from exc + + client = anthropic.AsyncAnthropic() + + requests = [ + Request( + custom_id=str(i), + params=MessageCreateParamsNonStreaming( + model=self.config.get("model"), + max_tokens=self.config.get("max_tokens"), + system="".join( + [ + user_m["content"] + for user_m in messages[0] + if user_m["role"] == "system" + ] + ), + messages=[user_m for user_m in m if user_m["role"] == "user"], + ), + ) + for i, m in enumerate(messages) + ] + + batch = await client.beta.messages.batches.create(requests=requests) + + start_clock = asyncio.get_running_loop().time() + while batch.processing_status != BatchStatus.COMPLETE.from_anthropic(): + batch = await client.beta.messages.batches.retrieve(batch.id) + + batch_time = asyncio.get_running_loop().time() - start_clock + if batch_time > self.config.get("batch_summary_time_limit", 24 * 60 * 60): + raise TimeoutError("Batch took too long to complete.") + + logger.info( + f"Summary batch status: {batch.processing_status} | Time elapsed: {batch_time}" + ) + await asyncio.sleep(self.config.get("batch_polling_interval", 30)) + + api_responses = await client.beta.messages.batches.results(batch.id) + responses = list(api_responses) + sorted_responses = sorted( + responses, key=lambda x: int(x.custom_id) + ) # The batch API doesn't guarantee the order of the responses + + return [ + Chunk( + text=response.result.message.content[0].text, + prompt_tokens=response.result.message.usage.input_tokens, + completion_tokens=response.result.message.usage.output_tokens, + ) + for response in sorted_responses + ] + + async def achat_iter(self): + raise NotImplementedError("Support for callbacks is not implemented yet.") + + def cosine_similarity(a, b): norm_product = np.linalg.norm(a, axis=1) * np.linalg.norm(b, axis=1) return a @ b.T / norm_product diff --git a/paperqa/settings.py b/paperqa/settings.py index bf7749e2..90978dd5 100644 --- a/paperqa/settings.py +++ b/paperqa/settings.py @@ -40,7 +40,14 @@ except ImportError: HAS_LDP_INSTALLED = False -from paperqa.llms import EmbeddingModel, LiteLLMModel, embedding_model_factory +from paperqa.llms import ( + AnthropicBatchLLMModel, + EmbeddingModel, + LiteLLMModel, + LLMBatchModel, + OpenAIBatchLLMModel, + embedding_model_factory, +) from paperqa.prompts import ( CONTEXT_INNER_PROMPT, CONTEXT_OUTER_PROMPT, @@ -581,6 +588,14 @@ def make_default_litellm_model_list_settings( } +def make_default_openai_batch_llm_settings(llm: str, temperature: float = 0.0) -> dict: + return { + "model": llm, + "temperature": temperature, + "max_tokens": 2048, + } + + class Settings(BaseSettings): model_config = SettingsConfigDict(extra="ignore") @@ -612,6 +627,27 @@ class Settings(BaseSettings): " router_kwargs key with router kwargs as values." ), ) + use_batch_in_summary: bool = Field( + default=False, + description=( + "Whether to use the batch API for LLMs in summarization, " + "which means multiple messages are sent in one API request " + "to the LLM provider's batch API. " + "This option is only available for Claude (https://docs.anthropic.com/en/api/creating-message-batches) " + "and OpenAI (https://platform.openai.com/docs/guides/batch) chat models." + ), + ) + batch_summary_time_limit: int = Field( + default=24 * 60 * 60, + description=( + "Time limit for batch summarization in seconds. " + "Default is set to 24 hours to match OpenAI's and Anthropic's limit." 
+ ), + batch_polling_interval: int = Field( + default=30, + description="Polling interval for batch summarization in seconds", + ) embedding: str = Field( default="text-embedding-3-small", description="Default embedding model for texts", @@ -795,7 +831,37 @@ def get_llm(self) -> LiteLLMModel: or make_default_litellm_model_list_settings(self.llm, self.temperature), ) - def get_summary_llm(self) -> LiteLLMModel: + def get_summary_llm(self) -> LiteLLMModel | LLMBatchModel: + if self.use_batch_in_summary: + import openai + + client = openai.OpenAI() + openai_models = [ + k.id + for k in client.models.list().data + if k.owned_by in {"system", "openai"} + ] + if self.summary_llm.startswith("claude-"): + return AnthropicBatchLLMModel( + name=self.summary_llm, + config=self.summary_llm_config + or make_default_openai_batch_llm_settings( + self.summary_llm, self.temperature + ), + ) + if self.summary_llm in openai_models: + return OpenAIBatchLLMModel( + name=self.summary_llm, + config=self.summary_llm_config + or make_default_openai_batch_llm_settings( + self.summary_llm, self.temperature + ), + ) + raise NotImplementedError( + "`use_batch_in_summary` is set to True, but the summary LLM is not supported " + "for batch processing.\nEither use a Claude or an OpenAI chat model or set " + "`use_batch_in_summary` to False." + ) return LiteLLMModel( name=self.summary_llm, config=self.summary_llm_config diff --git a/pyproject.toml b/pyproject.toml index ac7f3f68..a3809e83 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -6,7 +6,7 @@ requires = ["setuptools>=64", "setuptools_scm>=8"] dev = [ "ipython>=8", # Pin to keep recent "mypy>=1.8", # Pin for mutable-override - "paper-qa[datasets,ldp,typing,zotero,local]", + "paper-qa[datasets,batch,ldp,typing,zotero,local]", "pre-commit>=3.4", # Pin to keep recent "pydantic~=2.0", "pylint-pydantic", @@ -71,6 +71,10 @@ readme = "README.md" requires-python = ">=3.11" [project.optional-dependencies] +batch = [ + "anthropic", + "openai", +] datasets = [ "datasets", ] diff --git a/tests/test_llms.py b/tests/test_llms.py index 69bd65c8..ff12efeb 100644 --- a/tests/test_llms.py +++ b/tests/test_llms.py @@ -1,20 +1,28 @@ +import json import pathlib import pickle from typing import Any -from unittest.mock import patch +from unittest.mock import AsyncMock, MagicMock, patch +import anthropic import litellm +import openai import pytest from paperqa import ( + AnthropicBatchLLMModel, HybridEmbeddingModel, LiteLLMEmbeddingModel, LiteLLMModel, + OpenAIBatchLLMModel, SentenceTransformerEmbeddingModel, SparseEmbeddingModel, embedding_model_factory, ) -from paperqa.llms import Chunk +from paperqa.llms import ( + BatchStatus, + Chunk, +) from tests.conftest import VCR_DEFAULT_MATCH_ON @@ -159,6 +167,420 @@ def test_pickling(self, tmp_path: pathlib.Path) -> None: assert llm.router.deployment_names == rehydrated_llm.router.deployment_names +class TestOpenAIBatchLLMModel: + @pytest.fixture(scope="class") + def config(self, request) -> dict[str, Any]: + model_name = request.param + return { + "model": model_name, + "temperature": 0.0, + "max_tokens": 64, + "batch_summary_time_limit": 24 * 60 * 60, + "batch_polling_interval": 5, + } + + @pytest.mark.parametrize( + "config", + [ + pytest.param("gpt-4o-mini", id="chat-model"), + pytest.param("gpt-3.5-turbo-instruct", id="completion-model"), + ], + indirect=True, + ) + @pytest.mark.asyncio + async def test_run_prompt(self, config: dict[str, Any], request) -> None: + + mock_client = AsyncMock(spec_set=openai.AsyncOpenAI()) + + 
mock_file_id = "file-123" + mock_client.files.create = AsyncMock(return_value=MagicMock(id=mock_file_id)) + + mock_batch_id = "batch_123" + mock_client.batches.create = AsyncMock( + return_value=MagicMock( + id=mock_batch_id, status=BatchStatus.PROGRESS.from_openai() + ) + ) + + if request.node.name == "test_run_prompt[completion-model]": + batch_retrieve_calls = [ + MagicMock( + id=mock_batch_id, + status=BatchStatus.FAILURE.from_openai(), + errors=MagicMock( + data=[ + MagicMock( + message=( + "Batch failed: The model gpt-3.5-turbo-instruct " + "is not supported for batch completions." + ) + ) + ] + ), + ), + ] + elif request.node.name == "test_run_prompt[chat-model]": + batch_retrieve_calls = [ + MagicMock(id=mock_batch_id, status=BatchStatus.PROGRESS.from_openai()), + MagicMock( + id=mock_batch_id, + status=BatchStatus.COMPLETE.from_openai(), + output_file_id="file-789", + ), + ] + mock_client.batches.retrieve = AsyncMock(side_effect=batch_retrieve_calls) + + sample_responses = [ + { + "id": "file-789", + "custom_id": "0", + "response": { + "body": { + "choices": [ + { + "index": 0, + "message": { + "role": "assistant", + "content": ( + 'The duck says "quack." This vocalization is characteristic of the species ' + "Anas platyrhynchos, commonly known as the mallard duck, which is often used " + "as a representative example for the duck family, Anatidae." + ), + "refusal": None, + }, + "logprobs": None, + "finish_reason": "stop", + } + ], + "usage": { + "prompt_tokens": 46, + "completion_tokens": 47, + "total_tokens": 93, + "prompt_tokens_details": { + "cached_tokens": 0, + "audio_tokens": 0, + }, + "completion_tokens_details": { + "reasoning_tokens": 0, + "audio_tokens": 0, + "accepted_prediction_tokens": 0, + "rejected_prediction_tokens": 0, + }, + }, + } + }, + }, + { + "id": "file-789", + "custom_id": "1", + "response": { + "body": { + "choices": [ + { + "index": 0, + "message": { + "role": "assistant", + "content": ( + 'The dog says "bark." This is a vocalization ' + "commonly associated with canines, used for " + "communication purposes such as alerting, expressing " + "excitement, or seeking attention." + ), + "refusal": None, + }, + "logprobs": None, + "finish_reason": "stop", + } + ], + "usage": { + "prompt_tokens": 46, + "completion_tokens": 34, + "total_tokens": 80, + "prompt_tokens_details": { + "cached_tokens": 0, + "audio_tokens": 0, + }, + "completion_tokens_details": { + "reasoning_tokens": 0, + "audio_tokens": 0, + "accepted_prediction_tokens": 0, + "rejected_prediction_tokens": 0, + }, + }, + } + }, + }, + { + "id": "file-789", + "custom_id": "2", + "response": { + "body": { + "choices": [ + { + "index": 0, + "message": { + "role": "assistant", + "content": ( + 'It seems you\'re quoting or referencing "the cat says." ' + "If you're looking for a specific context, such as a phrase, a song, " + "or a scientific observation (like feline vocalizations), please provide " + "more details for a precise response." 
+ ), + "refusal": None, + }, + "logprobs": None, + "finish_reason": "stop", + } + ], + "usage": { + "prompt_tokens": 46, + "completion_tokens": 46, + "total_tokens": 92, + "prompt_tokens_details": { + "cached_tokens": 0, + "audio_tokens": 0, + }, + "completion_tokens_details": { + "reasoning_tokens": 0, + "audio_tokens": 0, + "accepted_prediction_tokens": 0, + "rejected_prediction_tokens": 0, + }, + }, + } + }, + }, + ] + + response_data = "\n".join(json.dumps(resp) for resp in sample_responses) + mock_response_content = MagicMock() + mock_response_content.read.return_value = response_data.encode() + mock_client.files.content = AsyncMock(return_value=mock_response_content) + + with patch("openai.AsyncOpenAI", return_value=mock_client): + llm = OpenAIBatchLLMModel(name=config["model"], config=config) + + outputs = [] + + def accum(x) -> None: + outputs.append(x) + + async def ac(x) -> None: + pass + + data = [{"animal": "duck"}, {"animal": "dog"}, {"animal": "cat"}] + + if request.node.name == "test_run_prompt[completion-model]": + with pytest.raises(RuntimeError) as e_info: + completion = await llm.run_prompt( + prompt="The {animal} says", + data=data, + ) + assert "Batch failed" in str(e_info.value) + assert "not supported" in str(e_info.value) + + if request.node.name == "test_run_prompt[chat-model]": + completion = await llm.run_prompt( + prompt="The {animal} says", + data=data, + callbacks=[accum, ac], + ) + + assert all( + completion[k].model == config["model"] for k in range(len(data)) + ) + assert all( + completion[k].seconds_to_first_token > 0 for k in range(len(data)) + ) + assert all(completion[k].prompt_count > 0 for k in range(len(data))) + assert all(completion[k].completion_count > 0 for k in range(len(data))) + assert all( + completion[k].completion_count <= config["max_tokens"] + for k in range(len(data)) + ) + assert sum(comp.cost for comp in completion) > 0 + assert all(str(completion[k]) == outputs[k] for k in range(len(data))) + + @pytest.mark.parametrize( + "config", + [ + pytest.param("gpt-4o-mini"), + ], + indirect=True, + ) + def test_pickling(self, tmp_path: pathlib.Path, config: dict[str, Any]) -> None: + pickle_path = tmp_path / "llm_model.pickle" + llm = OpenAIBatchLLMModel( + name="gpt-4o-mini", + config=config, + ) + with pickle_path.open("wb") as f: + pickle.dump(llm, f) + with pickle_path.open("rb") as f: + rehydrated_llm = pickle.load(f) + assert llm.name == rehydrated_llm.name + assert llm.config == rehydrated_llm.config + + +class TestAnthropicBatchLLMModel: + @pytest.fixture(scope="class") + def config(self, request) -> dict[str, Any]: + model_name = request.param + return { + "model": model_name, + "temperature": 0.0, + "max_tokens": 64, + "batch_summary_time_limit": 24 * 60 * 60, + "batch_polling_interval": 5, + } + + @pytest.mark.asyncio + @pytest.mark.parametrize( + "config", + [ + pytest.param("claude-3-haiku-20240307", id="chat-model"), + ], + indirect=True, + ) + async def test_run_prompt(self, config: dict[str, Any]) -> None: + + mock_client = AsyncMock(spec_set=anthropic.AsyncAnthropic()) + + mock_client = MagicMock() + mock_batches = MagicMock() + mock_client.beta.messages.batches = mock_batches + + mock_batch_id = "msgbatch_123" + mock_batches.create = AsyncMock( + return_value=MagicMock( + id=mock_batch_id, + processing_status=BatchStatus.PROGRESS.from_anthropic(), + ), + ) + + batch_retrieve_call = [ + MagicMock( + id=mock_batch_id, + processing_status=BatchStatus.PROGRESS.from_anthropic(), + ), + MagicMock( + id=mock_batch_id, + 
processing_status=BatchStatus.COMPLETE.from_anthropic(), + ), + ] + mock_batches.retrieve = AsyncMock(side_effect=batch_retrieve_call) + + mock_responses = [ + MagicMock( + custom_id="0", + result=MagicMock( + message=MagicMock( + id="msg_0143L9rPswgaUyENkHkPJLcn", + content=[ + MagicMock( + text=( + "I don't actually hear any ducks saying anything. " + "As an AI assistant, I don't have the ability to hear or interpret " + "sounds from the physical world. I can only respond based on the text " + "you provide to me through this chat interface. " + "If you'd like, you can tell me what you think the duck is" + ), + ) + ], + model="claude-3-haiku-20240307", + role="assistant", + stop_reason="max_tokens", + stop_sequence=None, + type="message", + usage=MagicMock(input_tokens=10, output_tokens=64), + ), + type="succeeded", + ), + ), + MagicMock( + custom_id="1", + result=MagicMock( + message=MagicMock( + id="msg_01KujiHEB5S8pfRUCmrbabu4", + content=[ + MagicMock( + text=( + "Unfortunately, I don't actually hear a dog speaking. " + "As an AI assistant without physical senses, I" + "can't directly perceive animals making sounds. " + "Could you please provide more context about what the " + "dog is saying, or what you would like me to respond to " + "regarding the dog? I'd be happy to try to assist" + ), + ) + ], + model="claude-3-haiku-20240307", + role="assistant", + stop_reason="max_tokens", + stop_sequence=None, + type="message", + usage=MagicMock(input_tokens=10, output_tokens=64), + ), + type="succeeded", + ), + ), + MagicMock( + custom_id="2", + result=MagicMock( + message=MagicMock( + id="msg_01Pf2LqV7wjnwqerkZubbofA", + content=[ + MagicMock( + text=( + "I'm afraid I don't actually hear a cat speaking. " + "As an AI assistant, I don't have the ability to hear " + "or communicate with animals directly. I can only respond " + "based on the text you provide to me. 
If you'd " + "like, you can tell me what you imagine the cat is saying, and I'll" + ), + ) + ], + model="claude-3-haiku-20240307", + role="assistant", + stop_reason="max_tokens", + stop_sequence=None, + type="message", + usage=MagicMock(input_tokens=10, output_tokens=64), + ), + type="succeeded", + ), + ), + ] + + # Create a generator function + def mock_results_generator(_batch_id): + + yield from mock_responses + + mock_batches.results = AsyncMock( + return_value=mock_results_generator(mock_batch_id) + ) + + with patch("anthropic.AsyncAnthropic", return_value=mock_client): + llm = AnthropicBatchLLMModel(name=config["model"], config=config) + + data = [{"animal": "duck"}, {"animal": "dog"}, {"animal": "cat"}] + + completions = await llm.run_prompt( + prompt="The {animal} says", + data=data, + ) + + assert all(comp.model == config["model"] for comp in completions) + assert all(comp.seconds_to_first_token > 0 for comp in completions) + assert all(comp.prompt_count > 0 for comp in completions) + assert all(comp.completion_count > 0 for comp in completions) + assert all( + comp.completion_count <= config["max_tokens"] for comp in completions + ) + assert sum(comp.cost for comp in completions) > 0 + + @pytest.mark.asyncio async def test_embedding_model_factory_sentence_transformer() -> None: """Test that the factory creates a SentenceTransformerEmbeddingModel when given an 'st-' prefix.""" diff --git a/uv.lock b/uv.lock index a45f5a87..6efb9abc 100644 --- a/uv.lock +++ b/uv.lock @@ -110,6 +110,24 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/78/b6/6307fbef88d9b5ee7421e68d78a9f162e0da4900bc5f5793f6d3d0e34fb8/annotated_types-0.7.0-py3-none-any.whl", hash = "sha256:1f02e8b43a8fbbc3f3e0d4f0f4bfc8131bcb4eebe8849b8e5c773f3a1c582a53", size = 13643 }, ] +[[package]] +name = "anthropic" +version = "0.39.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "anyio" }, + { name = "distro" }, + { name = "httpx" }, + { name = "jiter" }, + { name = "pydantic" }, + { name = "sniffio" }, + { name = "typing-extensions" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/79/02/2ea51930009d7537c4648f51d1bb3202ec76704cbb39a2a863ab38bee3dd/anthropic-0.39.0.tar.gz", hash = "sha256:94671cc80765f9ce693f76d63a97ee9bef4c2d6063c044e983d21a2e262f63ba", size = 189339 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/94/61/2580eaa171cab20708d59d39cadd15f78a6c617759e8d0a12e18fe3302d1/anthropic-0.39.0-py3-none-any.whl", hash = "sha256:ea17093ae0ce0e1768b0c46501d6086b5bcd74ff39d68cd2d6396374e9de7c09", size = 198392 }, +] + [[package]] name = "anyio" version = "4.6.2.post1" @@ -1511,7 +1529,7 @@ wheels = [ [[package]] name = "paper-qa" -version = "5.4.1.dev16+ga004d22" +version = "5.4.1.dev42+gd7dbd72.d20241119" source = { editable = "." 
} dependencies = [ { name = "aiohttp" }, @@ -1535,6 +1553,10 @@ dependencies = [ ] [package.optional-dependencies] +batch = [ + { name = "anthropic" }, + { name = "openai" }, +] datasets = [ { name = "datasets" }, ] @@ -1555,10 +1577,12 @@ zotero = [ [package.dev-dependencies] dev = [ + { name = "anthropic" }, { name = "datasets" }, { name = "ipython" }, { name = "ldp" }, { name = "mypy" }, + { name = "openai" }, { name = "pandas-stubs" }, { name = "pre-commit" }, { name = "pydantic" }, @@ -1583,6 +1607,7 @@ dev = [ [package.metadata] requires-dist = [ { name = "aiohttp" }, + { name = "anthropic", marker = "extra == 'batch'" }, { name = "anyio" }, { name = "coredis" }, { name = "datasets", marker = "extra == 'datasets'" }, @@ -1593,6 +1618,7 @@ requires-dist = [ { name = "limits" }, { name = "litellm", specifier = ">=1.44" }, { name = "numpy" }, + { name = "openai", marker = "extra == 'batch'" }, { name = "pandas-stubs", marker = "extra == 'typing'" }, { name = "pybtex" }, { name = "pydantic", specifier = "~=2.0" }, @@ -1613,7 +1639,7 @@ requires-dist = [ dev = [ { name = "ipython", specifier = ">=8" }, { name = "mypy", specifier = ">=1.8" }, - { name = "paper-qa", extras = ["datasets", "ldp", "typing", "zotero", "local"] }, + { name = "paper-qa", extras = ["datasets", "batch", "ldp", "typing", "zotero", "local"] }, { name = "pre-commit", specifier = ">=3.4" }, { name = "pydantic", specifier = "~=2.0" }, { name = "pylint-pydantic" },
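
Usage sketch for the settings introduced in this patch. Only use_batch_in_summary, batch_polling_interval, and batch_summary_time_limit come from the diff above; the Docs.aadd/Docs.aquery calls, the placeholder paper.pdf path, and the printed attribute reflect the existing paper-qa interface as assumed here, so treat this as an illustrative sketch rather than a verified example:

import asyncio

from paperqa import Docs, Settings


async def main() -> None:
    # Route evidence summaries through the provider's batch API.
    # Only OpenAI and Claude chat models are supported for this path.
    settings = Settings(
        summary_llm="gpt-4o-mini",
        use_batch_in_summary=True,
        batch_polling_interval=30,  # seconds between batch status polls
        batch_summary_time_limit=24 * 60 * 60,  # give up after 24 hours
    )
    docs = Docs()
    await docs.aadd("paper.pdf", settings=settings)  # placeholder document path
    session = await docs.aquery("What does this paper conclude?", settings=settings)
    print(session.formatted_answer)


if __name__ == "__main__":
    asyncio.run(main())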