Commit
Make OpenAI, Anthropic calls async by default
parkervg committed Oct 17, 2024
1 parent 0f4f51b commit d97af1f
Showing 6 changed files with 92 additions and 57 deletions.
6 changes: 6 additions & 0 deletions blendsql/db/_database.py
@@ -12,6 +12,12 @@ class Database(ABC):
db_url: Union[URL, str] = attrib()
lazy_tables: LazyTables = LazyTables()

def __str__(self):
return f"{self.__class__} @ {self.db_url}"

def __repr__(self):
return f"{self.__class__} @ {self.db_url}"

@abstractmethod
def _reset_connection(self) -> None:
"""Reset connection, so that temp tables are cleared."""
2 changes: 1 addition & 1 deletion blendsql/ingredients/builtin/join/main.py
@@ -79,7 +79,7 @@ def make_predictions(lm, left_values, right_values):
messages.append(user(current_example.to_string()))
prompt = "".join([i["content"] for i in messages])
response = (
generate(model, messages=messages)
generate(model, messages_list=[messages])[0]
.removeprefix("```json")
.removesuffix("```")
)
4 changes: 2 additions & 2 deletions blendsql/ingredients/builtin/qa/main.py
@@ -121,9 +121,9 @@ def __call__(
)
response = generate(
model,
messages=messages,
messages_list=[messages],
max_tokens=max_tokens,
).strip()
)[0].strip()
prompt = "".join([i["content"] for i in messages])
# Map from modified options to original, as they appear in DB
response: str = options_alias_to_original.get(response, response)
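
Both call sites above follow the same pattern: the batched generate() now takes a list of conversations (messages_list) and returns a list of completions, so a caller with a single prompt wraps its messages in a one-element list and unwraps the first result. A minimal self-contained sketch of that pattern, using a stub in place of the real generate() (the stub is illustrative only, not blendsql code):

from typing import Dict, List

def generate(model: object, messages_list: List[List[Dict[str, str]]], **kwargs) -> List[str]:
    # Stub with the same shape as blendsql.ingredients.generate.generate:
    # one completion string per conversation, in input order.
    return [f"completion for {len(messages)} message(s)" for messages in messages_list]

messages = [{"role": "user", "content": "Answer in one word: what season is June in?"}]
# Single prompt through the batched interface: wrap, call, take index 0.
response = generate(None, messages_list=[messages], max_tokens=10)[0].strip()
print(response)
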
125 changes: 76 additions & 49 deletions blendsql/ingredients/generate.py
@@ -1,71 +1,94 @@
from functools import singledispatch
import asyncio
from asyncio import Semaphore
import logging
from colorama import Fore
from typing import Optional, List

from .._logger import logger
from ..models import Model, OllamaLLM, OpenaiLLM, AnthropicLLM

sem = Semaphore(5)

system = lambda x: {"role": "system", "content": x}
assistant = lambda x: {"role": "assistant", "content": x}
user = lambda x: {"role": "user", "content": x}


@singledispatch
def generate(model: Model, messages: List[dict], *args, **kwargs) -> str:
def generate(model: Model, *args, **kwargs) -> str:
pass


@generate.register(OpenaiLLM)
def generate_openai(
async def run_openai_async_completions(
model: OpenaiLLM,
messages: List[dict],
messages_list: List[List[dict]],
max_tokens: Optional[int] = None,
stop_at: Optional[List[str]] = None,
**kwargs,
) -> str:
):
client: "AsyncOpenAI" = model.model_obj
async with sem:
responses = [
client.chat.completions.create(
model=model.model_name_or_path,
messages=messages,
max_tokens=max_tokens,
stop=stop_at,
**model.load_model_kwargs,
)
for messages in messages_list
]
return [m.choices[0].message.content for m in await asyncio.gather(*responses)]


@generate.register(OpenaiLLM)
def generate_openai(model: OpenaiLLM, *args, **kwargs) -> List[str]:
"""This function only exists because of a bug in guidance
https://github.com/guidance-ai/guidance/issues/881
https://gist.github.com/neubig/80de662fb3e225c18172ec218be4917a
"""
client = model.model_obj.engine.client
return (
client.chat.completions.create(
model=model.model_obj.engine.model_name,
messages=messages,
max_tokens=max_tokens,
stop=stop_at,
**model.load_model_kwargs,
)
.choices[0]
.message.content
return asyncio.get_event_loop().run_until_complete(
run_openai_async_completions(model, *args, **kwargs)
)


@generate.register(AnthropicLLM)
def generate_anthropic(
async def run_anthropic_async_completions(
model: AnthropicLLM,
messages: List[dict],
messages_list: List[List[dict]],
max_tokens: Optional[int] = None,
stop_at: Optional[List[str]] = None,
**kwargs,
):
client = model.model_obj.engine.anthropic
client: "AsyncAnthropic" = model.model_obj
async with sem:
responses = [
client.messages.create(
model=model.model_name_or_path,
messages=messages,
max_tokens=max_tokens or 4000,
# stop_sequences=stop_at
**model.load_model_kwargs,
)
for messages in messages_list
]
return [m.content[0].text for m in await asyncio.gather(*responses)]

return (
client.messages.create(
model=model.model_obj.engine.model_name,
messages=messages,
max_tokens=max_tokens or 4000,
# stop_sequences=stop_at
**model.load_model_kwargs,
)
.content[0]
.text

@generate.register(AnthropicLLM)
def generate_anthropic(
model: AnthropicLLM,
*args,
**kwargs,
) -> List[str]:
return asyncio.get_event_loop().run_until_complete(
run_anthropic_async_completions(model, *args, **kwargs)
)


@generate.register(OllamaLLM)
def generate_ollama(model: OllamaLLM, messages: List[dict], **kwargs) -> str:
def generate_ollama(model: OllamaLLM, messages_list: List[List[dict]], **kwargs) -> str:
"""Helper function to work with Ollama models,
since they're not recognized natively in the guidance ecosystem.
"""
@@ -76,7 +99,7 @@ def generate_ollama(model: OllamaLLM, messages: List[dict], **kwargs) -> str:
# )
from ollama import Options

# Turn outlines kwargs into Ollama
# Turn guidance kwargs into Ollama
if "stop_at" in kwargs:
stop_at = kwargs.pop("stop_at")
if isinstance(stop_at, str):
@@ -86,20 +109,24 @@ def generate_ollama(model: OllamaLLM, messages: List[dict], **kwargs) -> str:
if options.get("temperature") is None:
options["temperature"] = 0.0
stream = logger.level <= logging.DEBUG
response = model.model_obj(
messages=messages,
options=options,
stream=stream,
) # type: ignore
if stream:
chunked_res = []
for chunk in response:
chunked_res.append(chunk["message"]["content"])
print(
Fore.CYAN + chunk["message"]["content"] + Fore.RESET,
end="",
flush=True,
)
print("\n")
return "".join(chunked_res)
return response["message"]["content"]
responses = []
for messages in messages_list:
response = model.model_obj(
messages=messages,
options=options,
stream=stream,
) # type: ignore
if stream:
chunked_res = []
for chunk in response:
chunked_res.append(chunk["message"]["content"])
print(
Fore.CYAN + chunk["message"]["content"] + Fore.RESET,
end="",
flush=True,
)
print("\n")
responses.append("".join(chunked_res))
continue
responses.append(response["message"]["content"])
return responses
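
The core of the change sits in generate.py above: each conversation becomes one request coroutine, a module-level Semaphore caps how many run at once, asyncio.gather collects the completions in input order, and a thin synchronous wrapper (registered per model class via functools.singledispatch) drives the event loop so callers such as LLMQA and LLMJoin stay synchronous. Below is a self-contained sketch of that pattern, with a fake_completion coroutine standing in for the AsyncOpenAI / AsyncAnthropic client calls and asyncio.run in place of get_event_loop().run_until_complete; in this sketch the semaphore is acquired inside each request coroutine, one common way to make the concurrency bound apply to the requests themselves:

import asyncio
from asyncio import Semaphore
from typing import Dict, List

sem = Semaphore(5)  # module-level cap on in-flight requests, as in the diff (Python 3.10+)

async def fake_completion(messages: List[Dict[str, str]]) -> str:
    # Stand-in for client.chat.completions.create(...) / client.messages.create(...)
    async with sem:  # the bound is enforced per request
        await asyncio.sleep(0.1)  # simulates the network round trip
        return f"echo: {messages[-1]['content']}"

async def run_async_completions(messages_list: List[List[Dict[str, str]]]) -> List[str]:
    tasks = [fake_completion(messages) for messages in messages_list]
    return await asyncio.gather(*tasks)  # results come back in input order

def generate_batch(messages_list: List[List[Dict[str, str]]]) -> List[str]:
    # Synchronous entry point, mirroring generate_openai / generate_anthropic above.
    return asyncio.run(run_async_completions(messages_list))

if __name__ == "__main__":
    batch = [[{"role": "user", "content": f"question {i}"}] for i in range(3)]
    print(generate_batch(batch))

Because the event loop blocks until the whole batch finishes, the ingredient call sites keep their synchronous shape while per-request latency is overlapped across the batch.
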
8 changes: 5 additions & 3 deletions tests/conftest.py
@@ -8,7 +8,6 @@
from blendsql.models import TransformersLLM, OllamaLLM, OpenaiLLM, AnthropicLLM, Model
from blendsql import LLMQA, LLMMap, LLMJoin
from blendsql.ingredients.builtin import DEFAULT_MAP_FEW_SHOT
from blendsql.ingredients.builtin import DEFAULT_QA_FEW_SHOT

load_dotenv()

@@ -56,13 +55,15 @@ def pytest_generate_tests(metafunc):
ingredient_sets = [
{LLMQA, LLMMap, LLMJoin},
{
LLMQA.from_args(few_shot_examples=DEFAULT_QA_FEW_SHOT, k=1),
LLMMap.from_args(
LLMQA.from_args(
k=1,
model=TransformersLLM(
"HuggingFaceTB/SmolLM-135M-Instruct",
caching=False,
config={"chat_template": ChatMLTemplate, "device_map": "cpu"},
),
),
LLMMap.from_args(
few_shot_examples=[
*DEFAULT_MAP_FEW_SHOT,
{
@@ -73,6 +74,7 @@
},
},
],
k=2,
batch_size=3,
),
LLMJoin.from_args(
4 changes: 2 additions & 2 deletions tests/models/test_models.py
@@ -95,7 +95,7 @@ def test_llmqa(db, model, ingredients):


@pytest.mark.long
def test_llmqa_with_string(db, model, ingredients):
def test_llmmap_with_string(db, model, ingredients):
res = blend(
query="""
SELECT COUNT(*) AS "June Count" FROM w
@@ -121,7 +121,7 @@ def test_unconstrained_llmqa(db, model, ingredients):
query="""
{{
LLMQA(
"In 5 words, what's this table about?",
"What's this table about?",
(SELECT * FROM w LIMIT 1),
options='sports;food;politics'
)
