Implement support for Batch APIs to gather evidence #687

Open · maykcaldas wants to merge 32 commits into main from batch_api
Changes shown are from 2 of the 32 commits.

Commits
2385dce
Implements OpenAIBatchLLMModel
maykcaldas Nov 14, 2024
8a21055
Incorporates OpenAIBatchLLMModel to get_evidence
maykcaldas Nov 14, 2024
5f59681
Merge branch 'main' into batch_api
maykcaldas Nov 15, 2024
6f7bbb5
Merge branch 'main' into batch_api
maykcaldas Nov 16, 2024
e8dc0d0
Started anthropic batch api support implementation
maykcaldas Nov 15, 2024
899de43
Removed the skip_system argument from the new classes and tests to ma…
maykcaldas Nov 15, 2024
16c3988
Switched to async OpenAI client
maykcaldas Nov 16, 2024
d10a268
Added logging to the batch processing
maykcaldas Nov 16, 2024
0fe9aa1
Created mock server to test openAI batch API
maykcaldas Nov 18, 2024
a9ad540
Implemented batch support to Anthropic
maykcaldas Nov 18, 2024
9a0a6c4
Merge branch 'main' into batch_api
maykcaldas Nov 18, 2024
723650d
Updated uv.lock to include imports for the batch API
maykcaldas Nov 18, 2024
660bfa0
Implements tests with a mocked server for anthropic
maykcaldas Nov 18, 2024
977a025
Added type hints to satisfy the pre-commit
maykcaldas Nov 19, 2024
ee351f2
Merge branch 'main' into batch_api
maykcaldas Nov 19, 2024
293658a
Updates uv on github actions to include extra requirements
maykcaldas Nov 19, 2024
1ad1c7c
Removed the --all-extras flag from uv in github workflow
maykcaldas Nov 19, 2024
af32005
Refactored OpenAiBatchStatus and AnthropicBatchStatus to make the cod…
maykcaldas Nov 19, 2024
63e4b39
[pre-commit.ci lite] apply automatic fixes
pre-commit-ci-lite[bot] Nov 19, 2024
f61e629
Merge branch 'main' into batch_api
maykcaldas Nov 19, 2024
d7dbd72
Cleaned unneeded comments
maykcaldas Nov 19, 2024
7c37f6d
Updated the way the system message is passed to anthropic
maykcaldas Nov 19, 2024
de18907
changed how the file is passed to openai
maykcaldas Nov 20, 2024
3e72bd4
[pre-commit.ci lite] apply automatic fixes
pre-commit-ci-lite[bot] Nov 20, 2024
7c7f4b8
Avoided writing to a file when sending the batch to openAi
maykcaldas Nov 20, 2024
6c8f186
Skipped writing a file. Instead, the content is directly passed to th…
maykcaldas Nov 20, 2024
0e43a7c
Merge branch 'main' into batch_api
maykcaldas Nov 20, 2024
17c26eb
Fixed lint error
maykcaldas Nov 20, 2024
c258306
Updated the batch time limit settings name
maykcaldas Nov 20, 2024
4b8e1c3
Removed type hints from docstrings in gather_with_batch
maykcaldas Nov 20, 2024
8b5c1fa
Added exception in map_fxn_summary to treat multiple responses
maykcaldas Nov 20, 2024
ab40b54
Added a description explaining the llm_type attribute
maykcaldas Nov 20, 2024
2 changes: 2 additions & 0 deletions paperqa/__init__.py
@@ -15,6 +15,7 @@
HybridEmbeddingModel,
LiteLLMEmbeddingModel,
LiteLLMModel,
OpenAIBatchLLMModel,
LLMModel,
LLMResult,
NumpyVectorStore,
@@ -38,6 +39,7 @@
"LLMResult",
"LiteLLMEmbeddingModel",
"LiteLLMModel",
"OpenAIBatchLLMModel"
"NumpyVectorStore",
"PQASession",
"QueryRequest",
78 changes: 62 additions & 16 deletions paperqa/docs.py
@@ -40,6 +40,7 @@
LLMResult,
PQASession,
Text,
Context,
set_llm_session_ids,
)
from paperqa.utils import (
@@ -50,6 +51,8 @@
maybe_is_text,
md5sum,
name_in_text,
extract_score,
strip_citations
)

logger = logging.getLogger(__name__)
@@ -600,23 +603,66 @@ async def aget_evidence(
)

with set_llm_session_ids(session.id):
            if evidence_settings.use_batch_in_summary:
                # TODO: should a `gather_with_batch` helper receive `matches` and return results, to keep this DRY?
                data = [
                    {
                        "question": session.question,
                        "citation": f"{m.name}: {m.doc.formatted_citation}",
                        "text": m.text,
                        "summary_length": answer_config.evidence_summary_length,
                        "evidence": m.name,
                    }
                    for m in matches
                ]

                llm_results = await prompt_runner(
                    data,
                    callbacks,
                )

                results_data = []
                scores = []
                for r in llm_results:
                    try:
                        parsed = llm_parse_json(r.text)
                        scores.append(parsed.pop("relevance_score"))
                        parsed.pop("question", None)  # just in case the question was echoed back
                        results_data.append(parsed)
                    except (ValueError, KeyError):
                        results_data.append({})
                        scores.append(extract_score(r.text))

                results = [
                    (
                        Context(
                            context=strip_citations(llm_result.text),
                            text=m,
                            model_extra={},
                            score=score,
                            **extra,
                        ),
                        llm_result,
                    )
                    for extra, m, llm_result, score in zip(
                        results_data, matches, llm_results, scores
                    )
                ]
            else:
                results = await gather_with_concurrency(
                    answer_config.max_concurrent_requests,
                    [
                        map_fxn_summary(
                            text=m,
                            question=session.question,
                            prompt_runner=prompt_runner,
                            extra_prompt_data={
                                "summary_length": answer_config.evidence_summary_length,
                                "citation": f"{m.name}: {m.doc.formatted_citation}",
                            },
                            parser=llm_parse_json if prompt_config.use_json else None,
                            callbacks=callbacks,
                        )
                        for m in matches
                    ],
                )

for _, llm_result in results:
session.add_tokens(llm_result)
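The TODO in the batch branch above asks whether a `gather_with_batch` helper should wrap this logic. A rough sketch of what such a helper could look like, reusing the parsing steps and imports from the diff (the name, signature, and defaults are hypothetical and not part of this PR):

async def gather_with_batch(
    matches,
    question: str,
    prompt_runner,
    extra_prompt_data: dict | None = None,
    callbacks=None,
) -> list[tuple[Context, LLMResult]]:
    """Build one batch payload from matches, run it, and map results back onto contexts."""
    data = [
        {
            "question": question,
            "citation": f"{m.name}: {m.doc.formatted_citation}",
            "text": m.text,
        }
        | (extra_prompt_data or {})
        for m in matches
    ]
    llm_results = await prompt_runner(data, callbacks)

    results_data, scores = [], []
    for result in llm_results:
        try:
            parsed = llm_parse_json(result.text)
            scores.append(parsed.pop("relevance_score"))
            parsed.pop("question", None)  # in case the model echoed the question back
            results_data.append(parsed)
        except (ValueError, KeyError):
            results_data.append({})
            scores.append(extract_score(result.text))

    return [
        (
            Context(
                context=strip_citations(result.text),
                text=m,
                score=score,
                **extra,
            ),
            result,
        )
        for extra, m, result, score in zip(results_data, matches, llm_results, scores)
    ]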
209 changes: 208 additions & 1 deletion paperqa/llms.py
@@ -19,6 +19,12 @@
from typing import Any, TypeVar, cast

import litellm

import openai
import json
import os
import tempfile

import numpy as np
import tiktoken
from pydantic import (
@@ -325,7 +331,7 @@ def count_tokens(self, text: str) -> int:
async def run_prompt(
self,
prompt: str,
data: dict,
data: dict | list[dict[str, str]],
callbacks: list[Callable] | None = None,
name: str | None = None,
skip_system: bool = False,
@@ -761,6 +767,207 @@ def infer_llm_type(self) -> str:
def count_tokens(self, text: str) -> int:
return litellm.token_counter(model=self.name, text=text)

class OpenAIBatchLLMModel(LLMModel):
"""A wrapper around the OpenAI library to use the batch API."""
name: str = "gpt-4o-mini"
    config: dict = Field(
        default_factory=dict,
        description="Configuration dictionary for this model. Currently supported keys are `model`, `max_tokens`, and `rate_limit`.",
    )

def write_jsonl(self,
data: list[dict[str, str]],
filename: str):

batch_template = {
"custom_id": None,
"method": "POST",
"url": self.config.get('endpoint'),
"body": {
"model": None,
"messages": None,
"max_tokens": None
}
}
with open(filename, "w") as f:
for i, d in enumerate(data):
batch_template["custom_id"] = str(i)
batch_template["body"]["model"] = self.config.get('model')
batch_template["body"]["messages"] = d
batch_template["body"]["max_tokens"] = self.config.get('max_tokens')
f.write(json.dumps(batch_template) + "\n")
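        # Illustrative example of a single request line this method emits, following
        # the OpenAI Batch API request format for /v1/chat/completions (values assumed):
        # {"custom_id": "0", "method": "POST", "url": "/v1/chat/completions",
        #  "body": {"model": "gpt-4o-mini",
        #           "messages": [{"role": "system", "content": "..."},
        #                        {"role": "user", "content": "..."}],
        #           "max_tokens": 2048}}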

@rate_limited
async def acomplete(self):
raise NotImplementedError("Only chat models are supported by openAI batch API.")

@rate_limited
async def acomplete_iter(self):
raise NotImplementedError("Async generator not supported for batch calls and nly chat models are supported by openAI batch API.")

async def _run_chat(
self,
prompt: str,
        data: list[dict[str, str]],
callbacks: list[Callable] | None = None,
name: str | None = None,
skip_system: bool = False,
system_prompt: str = default_system_prompt,
) -> list[LLMResult]:
if callbacks:
sync_callbacks = [f for f in callbacks if not is_coroutine_callable(f)]
async_callbacks = [f for f in callbacks if is_coroutine_callable(f)]

system_message_prompt = {"role": "system", "content": system_prompt}
human_message_prompt = {"role": "user", "content": prompt}

batch = []
for d in data:
messages = [
{"role": m["role"], "content": m["content"].format(**d)}
for m in (
[human_message_prompt]
if skip_system
else [system_message_prompt, human_message_prompt]
)
]
batch.append(messages)

start_clock = asyncio.get_running_loop().time()
chunks = await self.achat(batch)
batch_time = asyncio.get_running_loop().time() - start_clock

if callbacks:
for chunk in chunks:
await do_callbacks(
async_callbacks, sync_callbacks, chunk.text, name
)

results = [
LLMResult(
model=self.name,
name=name,
prompt=messages,
prompt_count=chunk.prompt_tokens,
text=chunk.text,
completion_count=chunk.completion_tokens,
seconds_to_first_token=batch_time,
seconds_to_last_token=batch_time,
) for messages, chunk in zip(batch, chunks)
]

return results

@rate_limited
async def achat(self,
messages: list[dict[str, str]]
) -> list[Chunk]:
client = openai.OpenAI()

        with tempfile.NamedTemporaryFile(suffix=".jsonl", delete=True) as tmp_file:
            tmp_filename = tmp_file.name
            self.write_jsonl(messages, tmp_filename)
            with open(tmp_filename, "rb") as f:
                file = client.files.create(
                    file=f,
                    purpose="batch",
                )

batch = client.batches.create(
input_file_id=file.id,
endpoint="/v1/chat/completions",
completion_window="24h",
metadata={
"description": ""
}
)

while batch.status != "completed":
batch = client.batches.retrieve(batch.id)
if batch.status == "failed":
raise Exception("Batch failed. \n\nReason: \n" + "\n".join([k.message for k in batch.errors.data]))
await asyncio.sleep(5)

responses = client.files.content(batch.output_file_id)
response_lines = responses.read().decode('utf-8').splitlines()
responses = [json.loads(line) for line in response_lines]
        sorted_responses = sorted(responses, key=lambda x: int(x["custom_id"]))  # the Batch API doesn't guarantee response order

chunks = [
Chunk(
text=response["response"]["body"]["choices"][0]["message"]["content"],
prompt_tokens=response["response"]["body"]["usage"]["prompt_tokens"],
completion_tokens=response["response"]["body"]["usage"]["completion_tokens"],
) for response in sorted_responses
]

return chunks

@rate_limited
async def achat_iter(self):
raise NotImplementedError("Async generator not supported for batch calls. Use achat instead.")

def infer_llm_type(self):
self.config['endpoint'] = "/v1/chat/completions"
return "chat"

def count_tokens(self, text: str) -> int:
return len(text) // 4

async def check_rate_limit(self, token_count: float, **kwargs) -> None:
if "rate_limit" in self.config:
await GLOBAL_LIMITER.try_acquire(
("client", self.name),
self.config["rate_limit"].get(self.name, None),
weight=max(int(token_count), 1),
**kwargs,
)


class AnthropicBatchLLMModel(LLMModel):
# TODO: This class is not implemented yet.

@rate_limited
async def acomplete(self):
raise NotImplementedError("Completion models are not supported yet")

@rate_limited
async def acomplete_iter(self):
raise NotImplementedError("Completion models are not supported yet")

    async def _run_chat(self):
        """Process the batch and call the chat completion method."""
        ...

@rate_limited
async def achat(self, messages):
...

@rate_limited
async def achat_iter(self):
raise NotImplementedError("support to callbacks is not implemented yet")

def infer_llm_type(self):
return "chat" #TODO: Support completion models

def count_tokens(self, text: str) -> int:
        return len(text) // 4  # TODO: check whether Anthropic exposes a tokenizer for this. Currently unused; token usage is read directly from the batch response.

def __getstate__(self):
# Prevent _router from being pickled, SEE: https://stackoverflow.com/a/2345953
state = super().__getstate__()
state["__dict__"] = state["__dict__"].copy()
state["__dict__"].pop("_router", None)
return state

async def check_rate_limit(self, token_count: float, **kwargs) -> None:
if "rate_limit" in self.config:
await GLOBAL_LIMITER.try_acquire(
("client", self.name),
self.config["rate_limit"].get(self.name, None),
weight=max(int(token_count), 1),
**kwargs,
)


def cosine_similarity(a, b):
norm_product = np.linalg.norm(a, axis=1) * np.linalg.norm(b, axis=1)
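A minimal usage sketch for the new class (not part of the diff): the prompt, config values, and questions below are made up, OPENAI_API_KEY must be set in the environment, and it assumes `run_prompt` dispatches to `_run_chat` as in the base `LLMModel`. Note the batch may take up to the 24-hour completion window to finish.

import asyncio

from paperqa.llms import OpenAIBatchLLMModel


async def main() -> None:
    model = OpenAIBatchLLMModel(
        name="gpt-4o-mini",
        config={"model": "gpt-4o-mini", "max_tokens": 2048},
    )
    # One dict per batch entry; the prompt template is formatted with each dict.
    data = [
        {"question": "What is the boiling point of water at sea level?"},
        {"question": "Name one noble gas."},
    ]
    results = await model.run_prompt(
        prompt="Answer briefly: {question}",
        data=data,
    )
    for result in results:
        print(result.text)


asyncio.run(main())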
25 changes: 24 additions & 1 deletion paperqa/settings.py
@@ -40,7 +40,7 @@
except ImportError:
HAS_LDP_INSTALLED = False

from paperqa.llms import EmbeddingModel, LiteLLMModel, embedding_model_factory
from paperqa.llms import EmbeddingModel, LiteLLMModel, OpenAIBatchLLMModel, embedding_model_factory
from paperqa.prompts import (
CONTEXT_INNER_PROMPT,
CONTEXT_OUTER_PROMPT,
@@ -577,6 +577,15 @@ def make_default_litellm_model_list_settings(
]
}

def make_default_openai_batch_llm_settings(
    llm: str, temperature: float = 0.0
) -> dict:
    return {
        "model": llm,
        "temperature": temperature,
        "max_tokens": 2048,
    }

class Settings(BaseSettings):
model_config = SettingsConfigDict(extra="ignore")
@@ -609,6 +618,10 @@ class Settings(BaseSettings):
" router_kwargs key with router kwargs as values."
),
)
    use_batch_in_summary: bool = Field(
        default=False,
        description="Whether to use a batch API for the summary LLM during evidence gathering.",
    )
embedding: str = Field(
default="text-embedding-3-small",
description="Default embedding model for texts",
@@ -793,6 +806,16 @@ def get_llm(self) -> LiteLLMModel:
)

def get_summary_llm(self) -> LiteLLMModel:
if self.use_batch_in_summary:
            # TODO: support other LLM providers as well.
            # TODO: fail early if the chosen summary LLM is not supported by the Batch API.
return OpenAIBatchLLMModel(
name=self.summary_llm,
config=self.summary_llm_config
or make_default_openai_batch_llm_settings(
self.summary_llm, self.temperature
),
)
return LiteLLMModel(
name=self.summary_llm,
config=self.summary_llm_config
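A rough sketch of how the new flag is exercised (values are illustrative; `summary_llm` and `get_summary_llm` already exist on `Settings`, while `use_batch_in_summary` is added by this PR):

from paperqa.settings import Settings

settings = Settings(
    summary_llm="gpt-4o-mini",
    use_batch_in_summary=True,  # route evidence summaries through the OpenAI Batch API
)

summary_model = settings.get_summary_llm()
print(type(summary_model).__name__)  # OpenAIBatchLLMModel when the flag is set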