From 3ee8077088d9283702820d276269488508baa2cb Mon Sep 17 00:00:00 2001 From: parkervg Date: Wed, 16 Oct 2024 16:51:39 -0400 Subject: [PATCH 1/9] Allow null context in LLMQA ingredient --- blendsql/ingredients/builtin/qa/examples.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/blendsql/ingredients/builtin/qa/examples.py b/blendsql/ingredients/builtin/qa/examples.py index a6dd803..03e9e15 100644 --- a/blendsql/ingredients/builtin/qa/examples.py +++ b/blendsql/ingredients/builtin/qa/examples.py @@ -1,4 +1,4 @@ -from attr import attrs, attrib, validators +from attr import attrs, attrib import pandas as pd from typing import Optional, List, Callable @@ -8,9 +8,9 @@ @attrs(kw_only=True) class QAExample(Example): question: str = attrib() - context: pd.DataFrame = attrib( + context: Optional[pd.DataFrame] = attrib( converter=lambda d: pd.DataFrame.from_dict(d) if isinstance(d, dict) else d, - validator=validators.instance_of(pd.DataFrame), + default=None, ) options: Optional[List[str]] = attrib(default=None) @@ -19,7 +19,8 @@ def to_string(self, context_formatter: Callable[[pd.DataFrame], str]) -> str: s += f"\n\nQuestion: {self.question}\n" if self.options is not None: s += f"Options: {', '.join(self.options)}\n" - s += f"Context:\n{context_formatter(self.context)}" + if self.context is not None: + s += f"Context:\n{context_formatter(self.context)}" s += "\nAnswer: " return s From 1496368c9222a87365c7eb1eeda3d3507fb85f47 Mon Sep 17 00:00:00 2001 From: parkervg Date: Thu, 17 Oct 2024 09:38:42 -0400 Subject: [PATCH 2/9] documentation updates --- README.md | 24 ++++++++++++------------ docs/index.md | 30 ++++++++++++++++++------------ 2 files changed, 30 insertions(+), 24 deletions(-) diff --git a/README.md b/README.md index 68c0357..168ae1a 100644 --- a/README.md +++ b/README.md @@ -28,6 +28,18 @@ pip install blendsql - (10/15/24) As of version 0.0.27, there is a new pattern for defining + retrieving few-shot prompts; check out [Few-Shot Prompting](#few-shot-prompting) in the README for more info - (10/15/24) Check out [Some Cool Things by Example](https://parkervg.github.io/blendsql/by-example/) for some recent language updates! +### Features + +- Supports many DBMS 💾 + - SQLite, PostgreSQL, DuckDB, Pandas (aka duckdb in a trenchcoat) +- Supports many models ✨ + - Transformers, OpenAI, Anthropic, Ollama +- Easily extendable to [multi-modal usecases](./examples/vqa-ingredient.ipynb) 🖼️ +- Write your normal queries - smart parsing optimizes what is passed to external functions 🧠 + - Traverses abstract syntax tree with [sqlglot](https://github.com/tobymao/sqlglot) to minimize LLM function calls 🌳 +- Constrained decoding with [guidance](https://github.com/guidance-ai/guidance) 🚀 + - When using local models, we only generate syntactically valid outputs according to query syntax + database contents +- LLM function caching, built on [diskcache](https://grantjenks.com/docs/diskcache/) 🔑 BlendSQL is a *superset of SQLite* for problem decomposition and hybrid question-answering with LLMs. @@ -138,18 +150,6 @@ Now, we have an intermediate representation for our LLM to use that is explainab For in-depth descriptions of the above queries, check out our [documentation](https://parkervg.github.io/blendsql/). 
-### Features - -- Supports many DBMS 💾 - - SQLite, PostgreSQL, DuckDB, Pandas (aka duckdb in a trenchcoat) -- Supports many models ✨ - - Transformers, OpenAI, Anthropic, Ollama -- Easily extendable to [multi-modal usecases](./examples/vqa-ingredient.ipynb) 🖼️ -- Smart parsing optimizes what is passed to external functions 🧠 - - Traverses abstract syntax tree with [sqlglot](https://github.com/tobymao/sqlglot) to minimize LLM function calls 🌳 -- Constrained decoding with [guidance](https://github.com/guidance-ai/guidance) 🚀 -- LLM function caching, built on [diskcache](https://grantjenks.com/docs/diskcache/) 🔑 - ## Quickstart ```python diff --git a/docs/index.md b/docs/index.md index a472870..74e4d33 100644 --- a/docs/index.md +++ b/docs/index.md @@ -25,6 +25,24 @@ pip install blendsql ``` + +### ✨ News +- (10/15/24) As of version 0.0.27, there is a new pattern for defining + retrieving few-shot prompts; check out [Few-Shot Prompting](#few-shot-prompting) in the README for more info +- (10/15/24) Check out [Some Cool Things by Example](https://parkervg.github.io/blendsql/by-example/) for some recent language updates! + +### Features + +- Supports many DBMS 💾 + - SQLite, PostgreSQL, DuckDB, Pandas (aka duckdb in a trenchcoat) +- Supports many models ✨ + - Transformers, OpenAI, Anthropic, Ollama +- Easily extendable to [multi-modal usecases](./examples/vqa-ingredient.ipynb) 🖼️ +- Write your normal queries - smart parsing optimizes what is passed to external functions 🧠 + - Traverses abstract syntax tree with [sqlglot](https://github.com/tobymao/sqlglot) to minimize LLM function calls 🌳 +- Constrained decoding with [guidance](https://github.com/guidance-ai/guidance) 🚀 + - When using local models, we only generate syntactically valid outputs according to query syntax + database contents +- LLM function caching, built on [diskcache](https://grantjenks.com/docs/diskcache/) 🔑 +- BlendSQL is a *superset of SQLite* for problem decomposition and hybrid question-answering with LLMs. As a result, we can *Blend* together... @@ -125,18 +143,6 @@ Now, we have an intermediate representation for our LLM to use that is explainab For in-depth descriptions of the above queries, check out our [documentation](https://parkervg.github.io/blendsql/). -### Features - -- Supports many DBMS 💾 - - SQLite, PostgreSQL, DuckDB, Pandas (aka duckdb in a trenchcoat) -- Supports many models ✨ - - Transformers, OpenAI, Anthropic, Ollama -- Easily extendable to [multi-modal usecases](./examples/vqa-ingredient.ipynb) 🖼️ -- Smart parsing optimizes what is passed to external functions 🧠 - - Traverses abstract syntax tree with [sqlglot](https://github.com/tobymao/sqlglot) to minimize LLM function calls 🌳 -- Constrained decoding with [guidance](https://github.com/guidance-ai/guidance) 🚀 -- LLM function caching, built on [diskcache](https://grantjenks.com/docs/diskcache/) 🔑 -
### Citation From b516a92dcd9c88f516ee18d5a1f59d33e59f9e18 Mon Sep 17 00:00:00 2001 From: parkervg Date: Thu, 17 Oct 2024 09:45:41 -0400 Subject: [PATCH 3/9] Remove coverage badge action --- .github/workflows/tests.yml | 92 ++++++++++++++++++------------------- 1 file changed, 46 insertions(+), 46 deletions(-) diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 6391692..ae0d2c3 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -42,50 +42,50 @@ jobs: run: | python -m tox - - name: "Upload coverage data" - uses: actions/upload-artifact@v4 - with: - name: covdata - path: .coverage.* - - coverage: - name: Coverage - needs: tests - runs-on: ubuntu-latest - steps: - - name: "Check out the repo" - uses: "actions/checkout@v2" - - - name: "Set up Python" - uses: "actions/setup-python@v2" - with: - python-version: "3.12" +# - name: "Upload coverage data" +# uses: actions/upload-artifact@v4 +# with: +# name: covdata +# path: .coverage.* - - name: "Install dependencies" - run: | - python -m pip install tox tox-gh-actions - - - name: "Download coverage data" - uses: actions/download-artifact@v4 - with: - name: covdata - - - name: "Combine" - run: | - python -m tox -e coverage - export TOTAL=$(python -c "import json;print(json.load(open('coverage.json'))['totals']['percent_covered_display'])") - echo "total=$TOTAL" >> $GITHUB_ENV - echo "### Total coverage: ${TOTAL}%" >> $GITHUB_STEP_SUMMARY - - - name: "Make badge" - uses: schneegans/dynamic-badges-action@v1.4.0 - with: - # GIST_TOKEN is a GitHub personal access token with scope "gist". - auth: ${{ secrets.GIST_TOKEN }} - gistID: e24f1214fdff3ab086b829b5f01f85a8 # replace with your real Gist id. - filename: covbadge.json - label: Coverage - message: ${{ env.total }}% - minColorRange: 50 - maxColorRange: 90 - valColorRange: ${{ env.total }} \ No newline at end of file +# coverage: +# name: Coverage +# needs: tests +# runs-on: ubuntu-latest +# steps: +# - name: "Check out the repo" +# uses: "actions/checkout@v2" +# +# - name: "Set up Python" +# uses: "actions/setup-python@v2" +# with: +# python-version: "3.12" +# +# - name: "Install dependencies" +# run: | +# python -m pip install tox tox-gh-actions +# +# - name: "Download coverage data" +# uses: actions/download-artifact@v4 +# with: +# name: covdata +# +# - name: "Combine" +# run: | +# python -m tox -e coverage +# export TOTAL=$(python -c "import json;print(json.load(open('coverage.json'))['totals']['percent_covered_display'])") +# echo "total=$TOTAL" >> $GITHUB_ENV +# echo "### Total coverage: ${TOTAL}%" >> $GITHUB_STEP_SUMMARY +# +# - name: "Make badge" +# uses: schneegans/dynamic-badges-action@v1.4.0 +# with: +# # GIST_TOKEN is a GitHub personal access token with scope "gist". +# auth: ${{ secrets.GIST_TOKEN }} +# gistID: e24f1214fdff3ab086b829b5f01f85a8 # replace with your real Gist id. 
+# filename: covbadge.json +# label: Coverage +# message: ${{ env.total }}% +# minColorRange: 50 +# maxColorRange: 90 +# valColorRange: ${{ env.total }} \ No newline at end of file From 70c127164dfc72a0559fce36a5e1e165d59c7af0 Mon Sep 17 00:00:00 2001 From: parkervg Date: Thu, 17 Oct 2024 17:47:19 -0400 Subject: [PATCH 4/9] Switching out guidance Anthropic/OpenAI objects for the clients themselves --- blendsql/models/remote/_anthropic.py | 6 ++---- blendsql/models/remote/_openai.py | 6 ++---- 2 files changed, 4 insertions(+), 8 deletions(-) diff --git a/blendsql/models/remote/_anthropic.py b/blendsql/models/remote/_anthropic.py index 3938b65..42ab2b9 100644 --- a/blendsql/models/remote/_anthropic.py +++ b/blendsql/models/remote/_anthropic.py @@ -63,8 +63,6 @@ def __init__( ) def _load_model(self) -> ModelObj: - from guidance.models import Anthropic + from anthropic import AsyncAnthropic - return Anthropic( - self.model_name_or_path, echo=False, api_key=os.getenv("ANTHROPIC_API_KEY") - ) + return AsyncAnthropic(api_key=os.getenv("ANTHROPIC_API_KEY")) diff --git a/blendsql/models/remote/_openai.py b/blendsql/models/remote/_openai.py index 9b29ae9..d2a7bd3 100644 --- a/blendsql/models/remote/_openai.py +++ b/blendsql/models/remote/_openai.py @@ -169,11 +169,9 @@ def __init__( ) def _load_model(self) -> ModelObj: - from guidance.models import OpenAI + from openai import AsyncClient - return OpenAI( - self.model_name_or_path, echo=False, api_key=os.getenv("OPENAI_API_KEY") - ) + return AsyncClient(api_key=os.getenv("OPENAI_API_KEY")) def _setup(self, **kwargs) -> None: openai_setup() From 5e1037b3d64f8fa6131faabb1b5da51a6517562d Mon Sep 17 00:00:00 2001 From: parkervg Date: Thu, 17 Oct 2024 17:48:54 -0400 Subject: [PATCH 5/9] Put batching logic into `MapProgram` To do this, we initialize the `MapExample` default values to `None`, and modify this attribute with the batched values before calling `to_string()` --- blendsql/ingredients/builtin/map/examples.py | 2 +- blendsql/ingredients/builtin/map/main.py | 233 +++++++++++-------- 2 files changed, 131 insertions(+), 104 deletions(-) diff --git a/blendsql/ingredients/builtin/map/examples.py b/blendsql/ingredients/builtin/map/examples.py index 2c5d1b6..c15c71c 100644 --- a/blendsql/ingredients/builtin/map/examples.py +++ b/blendsql/ingredients/builtin/map/examples.py @@ -45,7 +45,7 @@ def to_string(self, include_values: bool = True) -> str: @attrs(kw_only=True) class MapExample(_MapExample): - values: List[str] = attrib() + values: List[str] = attrib(default=None) @attrs(kw_only=True) diff --git a/blendsql/ingredients/builtin/map/main.py b/blendsql/ingredients/builtin/map/main.py index ea6d1a1..07b29c7 100644 --- a/blendsql/ingredients/builtin/map/main.py +++ b/blendsql/ingredients/builtin/map/main.py @@ -1,3 +1,4 @@ +import copy import logging from typing import Union, Iterable, Any, Dict, Optional, List, Callable, Tuple from pathlib import Path @@ -35,13 +36,27 @@ def __call__( self, model: Model, current_example: MapExample, + values: List[str], few_shot_examples: List[AnnotatedMapExample], + batch_size: int, list_options_in_prompt: bool = True, max_tokens: Optional[int] = None, regex: Optional[str] = None, **kwargs, ) -> Tuple[str, str]: + # Only use tqdm if we're in debug mode + context_manager: Iterable = ( + tqdm( + range(0, len(values), batch_size), + total=len(values) // batch_size, + desc=f"Making calls to Model with batch_size {batch_size}", + bar_format="{l_bar}%s{bar}%s{r_bar}" % (Fore.CYAN, Fore.RESET), + ) + if logger.level <= 
logging.DEBUG + else range(0, len(values), batch_size) + ) if isinstance(model, LocalModel): + prompts = [] options = current_example.options if all(x is not None for x in [options, regex]): raise IngredientException( @@ -58,8 +73,7 @@ def __call__( for k, v in example.mapping.items(): lm += f"\n{k} -> {v}" lm += "\n\n---" - lm += current_example.to_string(include_values=False) - prompt = lm._current_prompt() + if isinstance(model, LocalModel) and regex is not None: gen_f = lambda: guidance.regex(pattern=regex) else: @@ -74,30 +88,85 @@ def make_predictions(lm, values, gen_f) -> guidance.models.Model: lm += guidance.capture(gen_f(), name=value) return lm - lm += make_predictions(values=current_example.values, gen_f=gen_f) - mapped_values = [lm[value] for value in current_example.values] - else: - # Use the 'old' style of prompting when we have a remote model - messages = [] - messages.append(user(MAIN_INSTRUCTION)) - # Add few-shot examples - for example in few_shot_examples: - messages.append(user(example.to_string())) - messages.append( - assistant(CONST.DEFAULT_ANS_SEP.join(example.mapping.values())) + mapped_values: List[str] = [] + for i in context_manager: + curr_batch_values = values[i : i + batch_size] + current_batch_example = copy.deepcopy(current_example) + current_batch_example.values = curr_batch_values + with guidance.user(): + batch_lm = lm + current_example.to_string(include_values=False) + prompts.append(batch_lm._current_prompt()) + with guidance.assistant(): + batch_lm += make_predictions( + values=current_batch_example.values, gen_f=gen_f + ) + mapped_values.extend( + [batch_lm[value] for value in current_batch_example.values] ) - # Add the current question + context for inference - messages.append(user(current_example.to_string())) - response = generate(model, messages=messages, max_tokens=max_tokens or 1000) + else: + messages_list: List[List[dict]] = [] + batch_sizes: List[int] = [] + for i in context_manager: + messages = [] + curr_batch_values = values[i : i + batch_size] + batch_sizes.append(len(curr_batch_values)) + current_batch_example = copy.deepcopy(current_example) + current_batch_example.values = curr_batch_values + messages.append(user(MAIN_INSTRUCTION)) + # Add few-shot examples + for example in few_shot_examples: + messages.append(user(example.to_string())) + messages.append( + assistant(CONST.DEFAULT_ANS_SEP.join(example.mapping.values())) + ) + # Add the current question + context for inference + messages.append(user(current_batch_example.to_string())) + messages_list.append(messages) + + responses: List[str] = generate( + model, messages_list=messages_list, max_tokens=max_tokens or 1000 + ) + # Post-process language model response - mapped_values = [ - i.strip() - for i in response.strip(CONST.DEFAULT_ANS_SEP).split( - CONST.DEFAULT_ANS_SEP - ) + mapped_values: List[str] = [] + for idx, r in enumerate(responses): + expected_len = batch_sizes[idx] + predictions = r.split(CONST.DEFAULT_ANS_SEP) + while len(predictions) < expected_len: + predictions.append(None) + mapped_values.extend(predictions) + prompts = [ + "".join([i["content"] for i in messages]) for messages in messages_list ] - prompt = "".join([i["content"] for i in messages]) - return mapped_values, prompt + # Try to map to booleans and `None` + mapped_values = [ + { + "t": True, + "f": False, + "true": True, + "false": False, + "y": True, + "n": False, + "yes": True, + "no": False, + CONST.DEFAULT_NAN_ANS: None, + }.get(i.lower(), i) + if isinstance(i, str) + else i + for i in mapped_values 
+ ] + # Try to cast strings as numerics + for idx, value in enumerate(mapped_values): + if not isinstance(value, str): + continue + value = value.replace(",", "") + try: + casted_value = literal_eval(value) + assert isinstance(casted_value, (float, int, str)) + mapped_values[idx] = casted_value + except (ValueError, SyntaxError, AssertionError): + continue + return mapped_values, prompts @attrs @@ -238,88 +307,46 @@ def run( if value_limit is not None: values = values[:value_limit] values = [value if not pd.isna(value) else "-" for value in values] - split_results: List[Union[str, None]] = [] - # Only use tqdm if we're in debug mode - context_manager: Iterable = ( - tqdm( - range(0, len(values), batch_size), - total=len(values) // batch_size, - desc=f"Making calls to Model with batch_size {batch_size}", - bar_format="{l_bar}%s{bar}%s{r_bar}" % (Fore.CYAN, Fore.RESET), - ) - if logger.level <= logging.DEBUG - else range(0, len(values), batch_size) - ) - for i in context_manager: - answer_length = len(values[i : i + batch_size]) - max_tokens = answer_length * 15 - curr_batch_values = values[i : i + batch_size] - current_example = MapExample( - **{ - "question": question, - "column_name": column_name, - "table_name": table_name, - "output_type": output_type, - "example_outputs": example_outputs, - "values": curr_batch_values, - } - ) - few_shot_examples: List[AnnotatedMapExample] = few_shot_retriever( - current_example.to_string() - ) - mapped_values: List[str] = model.predict( - program=MapProgram, - current_example=current_example, - question=question, - few_shot_examples=few_shot_examples, - options=options, - list_options_in_prompt=list_options_in_prompt, - example_outputs=example_outputs, - output_type=output_type, - table_name=table_name, - column_name=column_name, - regex=regex, - max_tokens=max_tokens, - **kwargs, - ) - # Try to map to booleans and `None` - mapped_values = [ - { - "t": True, - "f": False, - "true": True, - "false": False, - "y": True, - "n": False, - "yes": True, - "no": False, - CONST.DEFAULT_NAN_ANS: None, - }.get(i.lower(), i) - for i in mapped_values - ] - expected_len = len(curr_batch_values) - if len(mapped_values) != expected_len: - logger.debug( - Fore.YELLOW - + f"Mismatch between length of values and answers!\nvalues:{expected_len}, answers:{len(mapped_values)}" - + Fore.RESET - ) - logger.debug(mapped_values) - split_results.extend(mapped_values) - for idx, i in enumerate(split_results): - if i is None: - continue - if isinstance(i, str): - i = i.replace(",", "") - try: - split_results[idx] = literal_eval(i) - assert isinstance(i, (float, int, str)) - except (ValueError, SyntaxError, AssertionError): - continue + # for i in context_manager: + # answer_length = len(values[i : i + batch_size]) + # max_tokens = answer_length * 15 + # curr_batch_values = values[i : i + batch_size] + current_example = MapExample( + **{ + "question": question, + "column_name": column_name, + "table_name": table_name, + "output_type": output_type, + "example_outputs": example_outputs, + # Random subset of values for few-shot example retrieval + # these will get replaced during batching later + "values": values[:10], + } + ) + few_shot_examples: List[AnnotatedMapExample] = few_shot_retriever( + current_example.to_string() + ) + mapped_values: List[str] = model.predict( + program=MapProgram, + current_example=current_example, + values=values, + question=question, + few_shot_examples=few_shot_examples, + batch_size=batch_size, + options=options, + 
list_options_in_prompt=list_options_in_prompt, + example_outputs=example_outputs, + output_type=output_type, + table_name=table_name, + column_name=column_name, + regex=regex, + # max_tokens=max_tokens, + **kwargs, + ) logger.debug( Fore.YELLOW - + f"Finished LLMMap with values:\n{json.dumps(dict(zip(values[:10], split_results[:10])), indent=4)}" + + f"Finished LLMMap with values:\n{json.dumps(dict(zip(values[:10], mapped_values[:10])), indent=4)}" + Fore.RESET ) - return split_results + return mapped_values From 6b05ba85cf218894b37c97877e3d0c74910f98a1 Mon Sep 17 00:00:00 2001 From: parkervg Date: Thu, 17 Oct 2024 17:50:08 -0400 Subject: [PATCH 6/9] Fix caching bug Previously, the `Database` object had a reference to a connection in memory, which incorrectly invalidated the cache across runs --- blendsql/models/_model.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/blendsql/models/_model.py b/blendsql/models/_model.py index 5f3c3be..f62201a 100644 --- a/blendsql/models/_model.py +++ b/blendsql/models/_model.py @@ -16,6 +16,7 @@ from .._program import Program, program_to_str from .._constants import IngredientKwarg from ..db.utils import truncate_df_content +from ..db._database import Database CONTEXT_TRUNCATION_LIMIT = 100 ModelObj = TypeVar("ModelObj") @@ -147,9 +148,11 @@ def _create_key(self, program: Type[Program], **kwargs) -> str: options_str = str( sorted( [ - (k, sorted(v) if isinstance(v, set) else v) + (k, str(sorted(v) if isinstance(v, set) else v)) for k, v in kwargs.items() if not callable(v) + and not isinstance(v, Database) + and "uuid" not in k ] ) ) From 0f4f51b83b521d32ad98b640e8adc6ab39290144 Mon Sep 17 00:00:00 2001 From: parkervg Date: Thu, 17 Oct 2024 17:50:29 -0400 Subject: [PATCH 7/9] Modify `model.predict()` to account for async batching --- blendsql/models/_model.py | 25 ++++++++++++++----------- 1 file changed, 14 insertions(+), 11 deletions(-) diff --git a/blendsql/models/_model.py b/blendsql/models/_model.py index f62201a..5ba11e2 100644 --- a/blendsql/models/_model.py +++ b/blendsql/models/_model.py @@ -120,17 +120,20 @@ def predict(self, program: Type[Program], **kwargs) -> str: self.raw_prompts.insert(-1, "") return response # Modify fields used for tracking Model usage - response: Union[str, List[str]] - prompt: str - response, prompt = program(model=self, **kwargs) - self.prompts.insert(-1, self.format_prompt(response, **kwargs)) - self.raw_prompts.insert(-1, prompt) - self.num_calls += 1 - if self.tokenizer is not None: - self.prompt_tokens += len(self.tokenizer.encode(prompt)) - self.completion_tokens += sum( - [len(self.tokenizer.encode(r)) for r in " ".join(response)] - ) + response: Any + prompts: Union[str, List[str]] + response, prompts = program(model=self, **kwargs) + if not isinstance(prompts, list): + prompts = [prompts] + for prompt in prompts: + self.prompts.insert(-1, self.format_prompt(response, **kwargs)) + self.raw_prompts.insert(-1, prompt) + self.num_calls += 1 + if self.tokenizer is not None: + self.prompt_tokens += len(self.tokenizer.encode(prompt)) + # self.completion_tokens += sum( + # [len(self.tokenizer.encode(r)) for r in " ".join(response)] + # ) if self.caching: self.cache[key] = response # type: ignore return response From d97af1f24420feddd38d6c59ab5dc38dd97d215b Mon Sep 17 00:00:00 2001 From: parkervg Date: Thu, 17 Oct 2024 17:51:01 -0400 Subject: [PATCH 8/9] Make OpenAI, Antrhopic calls async by default --- blendsql/db/_database.py | 6 ++ blendsql/ingredients/builtin/join/main.py | 2 +- 
blendsql/ingredients/builtin/qa/main.py | 4 +- blendsql/ingredients/generate.py | 125 +++++++++++++--------- tests/conftest.py | 8 +- tests/models/test_models.py | 4 +- 6 files changed, 92 insertions(+), 57 deletions(-) diff --git a/blendsql/db/_database.py b/blendsql/db/_database.py index 501dc82..77230e2 100644 --- a/blendsql/db/_database.py +++ b/blendsql/db/_database.py @@ -12,6 +12,12 @@ class Database(ABC): db_url: Union[URL, str] = attrib() lazy_tables: LazyTables = LazyTables() + def __str__(self): + return f"{self.__class__} @ {self.db_url}" + + def __repr__(self): + return f"{self.__class__} @ {self.db_url}" + @abstractmethod def _reset_connection(self) -> None: """Reset connection, so that temp tables are cleared.""" diff --git a/blendsql/ingredients/builtin/join/main.py b/blendsql/ingredients/builtin/join/main.py index b5ab185..049028d 100644 --- a/blendsql/ingredients/builtin/join/main.py +++ b/blendsql/ingredients/builtin/join/main.py @@ -79,7 +79,7 @@ def make_predictions(lm, left_values, right_values): messages.append(user(current_example.to_string())) prompt = "".join([i["content"] for i in messages]) response = ( - generate(model, messages=messages) + generate(model, messages_list=[messages])[0] .removeprefix("```json") .removesuffix("```") ) diff --git a/blendsql/ingredients/builtin/qa/main.py b/blendsql/ingredients/builtin/qa/main.py index a32879f..fb005bf 100644 --- a/blendsql/ingredients/builtin/qa/main.py +++ b/blendsql/ingredients/builtin/qa/main.py @@ -121,9 +121,9 @@ def __call__( ) response = generate( model, - messages=messages, + messages_list=[messages], max_tokens=max_tokens, - ).strip() + )[0].strip() prompt = "".join([i["content"] for i in messages]) # Map from modified options to original, as they appear in DB response: str = options_alias_to_original.get(response, response) diff --git a/blendsql/ingredients/generate.py b/blendsql/ingredients/generate.py index d3868b2..ada5b33 100644 --- a/blendsql/ingredients/generate.py +++ b/blendsql/ingredients/generate.py @@ -1,4 +1,6 @@ from functools import singledispatch +import asyncio +from asyncio import Semaphore import logging from colorama import Fore from typing import Optional, List @@ -6,66 +8,87 @@ from .._logger import logger from ..models import Model, OllamaLLM, OpenaiLLM, AnthropicLLM +sem = Semaphore(5) + system = lambda x: {"role": "system", "content": x} assistant = lambda x: {"role": "assistant", "content": x} user = lambda x: {"role": "user", "content": x} @singledispatch -def generate(model: Model, messages: List[dict], *args, **kwargs) -> str: +def generate(model: Model, *args, **kwargs) -> str: pass -@generate.register(OpenaiLLM) -def generate_openai( +async def run_openai_async_completions( model: OpenaiLLM, - messages: List[dict], + messages_list: List[List[dict]], max_tokens: Optional[int] = None, stop_at: Optional[List[str]] = None, **kwargs, -) -> str: +): + client: "AsyncOpenAI" = model.model_obj + async with sem: + responses = [ + client.chat.completions.create( + model=model.model_name_or_path, + messages=messages, + max_tokens=max_tokens, + stop=stop_at, + **model.load_model_kwargs, + ) + for messages in messages_list + ] + return [m.choices[0].message.content for m in await asyncio.gather(*responses)] + + +@generate.register(OpenaiLLM) +def generate_openai(model: OpenaiLLM, *args, **kwargs) -> List[str]: """This function only exists because of a bug in guidance https://github.com/guidance-ai/guidance/issues/881 + + https://gist.github.com/neubig/80de662fb3e225c18172ec218be4917a """ - 
client = model.model_obj.engine.client - return ( - client.chat.completions.create( - model=model.model_obj.engine.model_name, - messages=messages, - max_tokens=max_tokens, - stop=stop_at, - **model.load_model_kwargs, - ) - .choices[0] - .message.content + return asyncio.get_event_loop().run_until_complete( + run_openai_async_completions(model, *args, **kwargs) ) -@generate.register(AnthropicLLM) -def generate_anthropic( +async def run_anthropic_async_completions( model: AnthropicLLM, - messages: List[dict], + messages_list: List[List[dict]], max_tokens: Optional[int] = None, stop_at: Optional[List[str]] = None, **kwargs, ): - client = model.model_obj.engine.anthropic + client: "AsyncAnthropic" = model.model_obj + async with sem: + responses = [ + client.messages.create( + model=model.model_name_or_path, + messages=messages, + max_tokens=max_tokens or 4000, + # stop_sequences=stop_at + **model.load_model_kwargs, + ) + for messages in messages_list + ] + return [m.content[0].text for m in await asyncio.gather(*responses)] - return ( - client.messages.create( - model=model.model_obj.engine.model_name, - messages=messages, - max_tokens=max_tokens or 4000, - # stop_sequences=stop_at - **model.load_model_kwargs, - ) - .content[0] - .text + +@generate.register(AnthropicLLM) +def generate_anthropic( + model: AnthropicLLM, + *args, + **kwargs, +) -> List[str]: + return asyncio.get_event_loop().run_until_complete( + run_anthropic_async_completions(model, *args, **kwargs) ) @generate.register(OllamaLLM) -def generate_ollama(model: OllamaLLM, messages: List[dict], **kwargs) -> str: +def generate_ollama(model: OllamaLLM, messages_list: List[List[dict]], **kwargs) -> str: """Helper function to work with Ollama models, since they're not recognized natively in the guidance ecosystem. 
""" @@ -76,7 +99,7 @@ def generate_ollama(model: OllamaLLM, messages: List[dict], **kwargs) -> str: # ) from ollama import Options - # Turn outlines kwargs into Ollama + # Turn guidance kwargs into Ollama if "stop_at" in kwargs: stop_at = kwargs.pop("stop_at") if isinstance(stop_at, str): @@ -86,20 +109,24 @@ def generate_ollama(model: OllamaLLM, messages: List[dict], **kwargs) -> str: if options.get("temperature") is None: options["temperature"] = 0.0 stream = logger.level <= logging.DEBUG - response = model.model_obj( - messages=messages, - options=options, - stream=stream, - ) # type: ignore - if stream: - chunked_res = [] - for chunk in response: - chunked_res.append(chunk["message"]["content"]) - print( - Fore.CYAN + chunk["message"]["content"] + Fore.RESET, - end="", - flush=True, - ) - print("\n") - return "".join(chunked_res) - return response["message"]["content"] + responses = [] + for messages in messages_list: + response = model.model_obj( + messages=messages, + options=options, + stream=stream, + ) # type: ignore + if stream: + chunked_res = [] + for chunk in response: + chunked_res.append(chunk["message"]["content"]) + print( + Fore.CYAN + chunk["message"]["content"] + Fore.RESET, + end="", + flush=True, + ) + print("\n") + responses.append("".join(chunked_res)) + continue + responses.append(response["message"]["content"]) + return responses diff --git a/tests/conftest.py b/tests/conftest.py index d31d90b..9ed5a82 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -8,7 +8,6 @@ from blendsql.models import TransformersLLM, OllamaLLM, OpenaiLLM, AnthropicLLM, Model from blendsql import LLMQA, LLMMap, LLMJoin from blendsql.ingredients.builtin import DEFAULT_MAP_FEW_SHOT -from blendsql.ingredients.builtin import DEFAULT_QA_FEW_SHOT load_dotenv() @@ -56,13 +55,15 @@ def pytest_generate_tests(metafunc): ingredient_sets = [ {LLMQA, LLMMap, LLMJoin}, { - LLMQA.from_args(few_shot_examples=DEFAULT_QA_FEW_SHOT, k=1), - LLMMap.from_args( + LLMQA.from_args( + k=1, model=TransformersLLM( "HuggingFaceTB/SmolLM-135M-Instruct", caching=False, config={"chat_template": ChatMLTemplate, "device_map": "cpu"}, ), + ), + LLMMap.from_args( few_shot_examples=[ *DEFAULT_MAP_FEW_SHOT, { @@ -73,6 +74,7 @@ def pytest_generate_tests(metafunc): }, }, ], + k=2, batch_size=3, ), LLMJoin.from_args( diff --git a/tests/models/test_models.py b/tests/models/test_models.py index 34a2fdf..ee3e8fb 100644 --- a/tests/models/test_models.py +++ b/tests/models/test_models.py @@ -95,7 +95,7 @@ def test_llmqa(db, model, ingredients): @pytest.mark.long -def test_llmqa_with_string(db, model, ingredients): +def test_llmmap_with_string(db, model, ingredients): res = blend( query=""" SELECT COUNT(*) AS "June Count" FROM w @@ -121,7 +121,7 @@ def test_unconstrained_llmqa(db, model, ingredients): query=""" {{ LLMQA( - "In 5 words, what's this table about?", + "What's this table about?", (SELECT * FROM w LIMIT 1), options='sports;food;politics' ) From e858f8921a46cca4d883cf8d85f1b4e5b467c693 Mon Sep 17 00:00:00 2001 From: parkervg Date: Thu, 17 Oct 2024 20:50:13 -0400 Subject: [PATCH 9/9] Logging missing values on LLMMap --- blendsql/ingredients/builtin/map/main.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/blendsql/ingredients/builtin/map/main.py b/blendsql/ingredients/builtin/map/main.py index 07b29c7..32d6562 100644 --- a/blendsql/ingredients/builtin/map/main.py +++ b/blendsql/ingredients/builtin/map/main.py @@ -129,12 +129,19 @@ def make_predictions(lm, values, gen_f) -> guidance.models.Model: # 
Post-process language model response mapped_values: List[str] = [] + total_missing_values = 0 for idx, r in enumerate(responses): expected_len = batch_sizes[idx] predictions = r.split(CONST.DEFAULT_ANS_SEP) while len(predictions) < expected_len: + total_missing_values += 1 predictions.append(None) mapped_values.extend(predictions) + if total_missing_values > 0: + logger.debug( + Fore.RED + + f"LLMMap with {type(model).__name__}({model.model_name_or_path}) only returned {len(mapped_values)-total_missing_values} out of {len(mapped_values)} values" + ) prompts = [ "".join([i["content"] for i in messages]) for messages in messages_list ]
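Below is a minimal, self-contained sketch of the batching contract that PATCH 5/9 moves into `MapProgram` and that PATCH 9/9 adds logging for: values are chunked by `batch_size`, each batch yields one delimited response, and short responses are padded with `None` so outputs stay aligned 1:1 with inputs. The helper names, the `;` separator, and the toy model function are assumptions for illustration only, not the library's API.

```python
from typing import Callable, List, Optional

ANS_SEP = ";"  # placeholder separator; the library defines its own constant


def map_in_batches(
    values: List[str],
    ask_model: Callable[[List[str]], str],
    batch_size: int = 5,
) -> List[Optional[str]]:
    """Chunk `values`, query the model once per chunk, and realign answers."""
    mapped: List[Optional[str]] = []
    missing = 0
    for i in range(0, len(values), batch_size):
        batch = values[i : i + batch_size]
        predictions = ask_model(batch).split(ANS_SEP)
        # Pad with None so every input value keeps a slot, even if the model
        # returns fewer answers than values (mirrors the PATCH 9/9 logging).
        while len(predictions) < len(batch):
            predictions.append(None)
            missing += 1
        mapped.extend(
            p.strip() if isinstance(p, str) else p
            for p in predictions[: len(batch)]
        )
    if missing:
        print(f"Model left {missing} of {len(values)} values unanswered")
    return mapped


# Toy model that answers only two of three values, to show the padding:
print(map_in_batches(["a", "b", "c"], lambda batch: "yes;no", batch_size=3))
# -> ['yes', 'no', None]
```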
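The cache fix in PATCH 6/9 is easiest to see in isolation: the key is built only from stable, stringified kwargs, skipping callables, the `Database` wrapper (which holds a live connection whose repr changes every run), and anything uuid-like. The `Database` stand-in and `make_cache_key` below are hypothetical illustrations of that filter, not the library's own classes.

```python
import hashlib


class Database:
    """Stand-in for a DB wrapper holding a live connection object."""

    def __init__(self, url: str):
        self.db_url = url
        self.con = object()  # fresh object each run -> unstable repr


def make_cache_key(prompt: str, **kwargs) -> str:
    # Keep only stable, serializable kwargs; stringify sets in sorted order.
    stable_items = sorted(
        (k, str(sorted(v) if isinstance(v, set) else v))
        for k, v in kwargs.items()
        if not callable(v) and not isinstance(v, Database) and "uuid" not in k
    )
    return hashlib.md5((prompt + str(stable_items)).encode()).hexdigest()


# The same logical call across two "runs" (fresh Database objects, reordered
# set) now produces the same key, so the on-disk cache actually hits.
k1 = make_cache_key("Q?", db=Database("sqlite:///t.db"), options={"b", "a"})
k2 = make_cache_key("Q?", db=Database("sqlite:///t.db"), options={"a", "b"})
assert k1 == k2
```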
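Finally, a sketch of the batched-async call pattern that PATCH 8/9 introduces for the remote clients, written against the `openai>=1.x` `AsyncOpenAI` client. The model name, the helper names, and the choice to acquire the semaphore per request (rather than around a whole batch, as the patch does) are assumptions made for this example.

```python
import asyncio
from typing import List, Optional

from openai import AsyncOpenAI


async def _one_completion(
    client: AsyncOpenAI,
    sem: asyncio.Semaphore,
    messages: List[dict],
    model: str,
    max_tokens: Optional[int],
) -> str:
    async with sem:  # cap how many requests are in flight at once
        response = await client.chat.completions.create(
            model=model, messages=messages, max_tokens=max_tokens
        )
        return response.choices[0].message.content


async def batched_completions(
    messages_list: List[List[dict]],
    model: str = "gpt-4o-mini",  # placeholder model name
    max_tokens: Optional[int] = None,
    max_concurrency: int = 5,
) -> List[str]:
    client = AsyncOpenAI()  # reads OPENAI_API_KEY from the environment
    sem = asyncio.Semaphore(max_concurrency)
    tasks = [
        _one_completion(client, sem, messages, model, max_tokens)
        for messages in messages_list
    ]
    # gather() preserves order, so responses line up with messages_list.
    return await asyncio.gather(*tasks)


if __name__ == "__main__":
    questions = ["What year was SQLite first released?", "Who created DuckDB?"]
    messages_list = [[{"role": "user", "content": q}] for q in questions]
    print(asyncio.run(batched_completions(messages_list, max_tokens=64)))
```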