From cf01b46e0ede47648c110487bac964a559312f29 Mon Sep 17 00:00:00 2001
From: parkervg
Date: Mon, 26 Aug 2024 11:37:29 -0400
Subject: [PATCH] Moving to guidance for quicker runtimes

---
 blendsql/ingredients/builtin/qa/main.py | 31 +++++++++++++++++++++----
 blendsql/ingredients/generate.py        | 10 +++++++-
 2 files changed, 35 insertions(+), 6 deletions(-)

diff --git a/blendsql/ingredients/builtin/qa/main.py b/blendsql/ingredients/builtin/qa/main.py
index 93b1c68..0fcb3d4 100644
--- a/blendsql/ingredients/builtin/qa/main.py
+++ b/blendsql/ingredients/builtin/qa/main.py
@@ -3,6 +3,7 @@ import pandas as pd
 import re
 
 import guidance
+from colorama import Fore
 
 from blendsql.models import Model, LocalModel
 from blendsql.ingredients.generate import generate
@@ -57,15 +58,16 @@ def __call__(
             m += f"\n\nContext: \n Table Description: {table_title} \n {serialized_db}"
         else:
             m += f"\n\nContext: \n {serialized_db}"
+        if options and not isinstance(model, LocalModel):
+            m += f"\n\nFor your answer, select one of the following options: {options}"
+        m += "\n\nAnswer:\n"
         if isinstance(model, LocalModel):
             prompt = m._current_prompt()
             if options is not None:
-                m += guidance.capture(
+                response = (m + guidance.capture(
                     guidance.select(options=[re.escape(str(i)) for i in options]),
                     name="result",
-                )
-                # Map from modified options to original, as they appear in DB
-                response: str = options_alias_to_original.get(m["result"], m["result"])
+                ))["result"]
             else:
                 response = guidance.capture(
                     m + guidance.gen(max_tokens=max_tokens, stop="\n"), name="response"
@@ -73,7 +75,26 @@ def __call__(
         else:
             prompt = m
             response = generate(
-                model, prompt=prompt, max_tokens=max_tokens, stop_at="\n"
+                model,
+                prompt=prompt,
+                options=options,
+                max_tokens=max(
+                    [
+                        len(model.tokenizer.encode(alias))
+                        for alias in options_alias_to_original
+                    ]
+                )
+                if options
+                else max_tokens,
+                stop_at="\n",
+            )
+        # Map from modified options to original, as they appear in DB
+        response: str = options_alias_to_original.get(response, response)
+        if options and response not in options:
+            print(
+                Fore.RED
+                + f"Model did not select from a valid option!\nExpected one of {options}, got {response}"
+                + Fore.RESET
             )
         return (response, prompt)
 
diff --git a/blendsql/ingredients/generate.py b/blendsql/ingredients/generate.py
index 642f782..f3001c9 100644
--- a/blendsql/ingredients/generate.py
+++ b/blendsql/ingredients/generate.py
@@ -2,6 +2,7 @@ import logging
 
 from colorama import Fore
 from typing import Optional
+from collections.abc import Collection
 
 from .._logger import logger
 from ..models import Model, OllamaLLM, OpenaiLLM
@@ -30,10 +31,17 @@ def generate_openai(
 
 
 @generate.register(OllamaLLM)
-def generate_ollama(model: OllamaLLM, prompt, **kwargs) -> str:
+def generate_ollama(
+    model: OllamaLLM, prompt, options: Optional[Collection[str]] = None, **kwargs
+) -> str:
     """Helper function to work with Ollama models,
     since they're not recognized in the Outlines ecosystem.
     """
+    if options:
+        raise NotImplementedError(
+            "Cannot use choice generation with an Ollama model "
+            + "due to the limitations of the Ollama API."
+        )
     from ollama import Options
 
     # Turn outlines kwargs into Ollama
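
Note on the approach: the local-model branch relies on guidance's grammar-constrained decoding, where select() restricts generation to a fixed option set and capture() names the matched span so it can be read back off the model state. Below is a minimal sketch of that pattern, not taken from blendsql; the model path, prompt, and option list are placeholders, and it assumes a local llama.cpp backend supported by guidance:

    import guidance
    from guidance import models

    # Placeholder local model; any guidance-supported local backend works.
    lm = models.LlamaCpp("path/to/model.gguf")

    lm += "Q: Is the sky blue?\n\nAnswer:\n"
    # select() constrains decoding so only one of the listed options can be
    # generated; capture() names the match for retrieval below.
    lm += guidance.capture(
        guidance.select(options=["yes", "no", "unknown"]),
        name="result",
    )
    print(lm["result"])  # always exactly one of: yes / no / unknown

This is also why the remote-model path caps max_tokens at the longest tokenized alias when options is given: an answer drawn from a fixed option set never needs more tokens than the longest option.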