Skip to content

Commit

Permalink
Moving to guidance for quicker runtimes
Browse files Browse the repository at this point in the history
  • Loading branch information
parkervg committed Aug 26, 2024
1 parent 1c706e5 commit cf01b46
Show file tree
Hide file tree
Showing 2 changed files with 35 additions and 6 deletions.
31 changes: 26 additions & 5 deletions blendsql/ingredients/builtin/qa/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import pandas as pd
import re
import guidance
from colorama import Fore

from blendsql.models import Model, LocalModel
from blendsql.ingredients.generate import generate
Expand Down Expand Up @@ -57,23 +58,43 @@ def __call__(
m += f"\n\nContext: \n Table Description: {table_title} \n {serialized_db}"
else:
m += f"\n\nContext: \n {serialized_db}"
if options and not isinstance(model, LocalModel):
m += f"\n\nFor your answer, select from one of the following options: {options}"
m += "\n\nAnswer:\n"
if isinstance(model, LocalModel):
prompt = m._current_prompt()
if options is not None:
m += guidance.capture(
response = guidance.capture(
guidance.select(options=[re.escape(str(i)) for i in options]),
name="result",
)
# Map from modified options to original, as they appear in DB
response: str = options_alias_to_original.get(m["result"], m["result"])
)["response"]
else:
response = guidance.capture(
m + guidance.gen(max_tokens=max_tokens, stop="\n"), name="response"
)["response"]
else:
prompt = m
response = generate(
model, prompt=prompt, max_tokens=max_tokens, stop_at="\n"
model,
prompt=prompt,
options=options,
max_tokens=max(
[
len(model.tokenizer.encode(alias))
for alias in options_alias_to_original
]
)
if options
else max_tokens,
stop_at="\n",
)
# Map from modified options to original, as they appear in DB
response: str = options_alias_to_original.get(response, response)
if options and response not in options:
print(
Fore.RED
+ f"Model did not select from a valid option!\nExpected one of {options}, got {response}"
+ Fore.RESET
)
return (response, prompt)

Expand Down
10 changes: 9 additions & 1 deletion blendsql/ingredients/generate.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import logging
from colorama import Fore
from typing import Optional
from collections.abc import Collection

from .._logger import logger
from ..models import Model, OllamaLLM, OpenaiLLM
Expand Down Expand Up @@ -30,10 +31,17 @@ def generate_openai(


@generate.register(OllamaLLM)
def generate_ollama(model: OllamaLLM, prompt, **kwargs) -> str:
def generate_ollama(
model: OllamaLLM, prompt, options: Optional[Collection[str]] = None, **kwargs
) -> str:
"""Helper function to work with Ollama models,
since they're not recognized in the Outlines ecosystem.
"""
if options:
raise NotImplementedError(
"Cannot use choice generation with an Ollama model"
+ "due to the limitations of the Ollama API."
)
from ollama import Options

# Turn outlines kwargs into Ollama
Expand Down

0 comments on commit cf01b46

Please sign in to comment.