Commit 59bd49d
Swapping out ImageCaption for the more versatile VQA
Now, we can leverage constrained generation with images
parkervg committed Jul 31, 2024
1 parent 6818562 commit 59bd49d
Showing 12 changed files with 246 additions and 130 deletions.
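As a rough usage sketch of what "constrained generation with images" enables here (the table, column, and question are hypothetical, and the calling convention is assumed to mirror LLMMap's 'question', 'table::column' ingredient syntax):

    from blendsql import blend, VQA

    # Hypothetical schema: a `paintings` table whose `image` column stores raw bytes.
    blendsql_query = """
    SELECT title, {{VQA('Is there a dog in this painting?', 'paintings::image')}} AS has_dog
    FROM paintings
    """
    smoothie = blend(
        query=blendsql_query,
        default_model=...,  # a TransformersVisionModel, so regex-constrained decoding is available
        ingredients={VQA},
        db=...,  # e.g. a SQLite database wrapper
    )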
blendsql/__init__.py (2 changes: 1 addition & 1 deletion)

@@ -1,5 +1,5 @@
 __version__ = "0.0.21"


-from .ingredients.builtin import LLMMap, LLMQA, LLMJoin, LLMValidate, ImageCaption
+from .ingredients.builtin import LLMMap, LLMQA, LLMJoin, LLMValidate, VQA
 from .blend import blend
blendsql/generate/choice.py (14 changes: 13 additions & 1 deletion)

@@ -2,7 +2,7 @@
 from typing import List
 import outlines

-from ..models import Model, OllamaLLM
+from ..models import Model, OllamaLLM, TransformersVisionModel


 @singledispatch
@@ -11,6 +11,18 @@ def choice(model: Model, prompt: str, choices: List[str], **kwargs) -> str:
     return generator(prompt)


+@choice.register(TransformersVisionModel)
+def choice_transformers_vision(
+    model: TransformersVisionModel,
+    prompt: str,
+    choices: List[str],
+    media=None,
+    **kwargs
+):
+    generator = outlines.generate.choice(model.model_obj, choices=choices)
+    return generator(prompt, media=media)
+
+
 @choice.register(OllamaLLM)
 def choice_ollama(*_, **__) -> str:
     """Helper function to work with Ollama models,
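The `@singledispatch` pattern here routes each call on the runtime type of the `model` argument, so adding vision support means registering a new handler rather than branching inside the generic function. A minimal standalone illustration of the mechanism (toy classes, not the real blendsql models):

    from functools import singledispatch

    class Model: ...
    class VisionModel(Model): ...

    @singledispatch
    def choice(model: Model, prompt: str, **kwargs) -> str:
        return f"text-only path: {prompt}"

    @choice.register(VisionModel)
    def _(model: VisionModel, prompt: str, media=None, **kwargs) -> str:
        return f"vision path: {prompt} ({len(media or [])} image(s))"

    print(choice(Model(), "pick one"))                        # text-only path
    print(choice(VisionModel(), "pick one", media=["img"]))   # vision path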
blendsql/generate/regex.py (16 changes: 15 additions & 1 deletion)

@@ -2,7 +2,7 @@
 from typing import Optional, List, Union
 import outlines

-from ..models import Model, OllamaLLM
+from ..models import Model, OllamaLLM, TransformersVisionModel


 @singledispatch
@@ -17,6 +17,20 @@ def regex(
     return generator(prompt, max_tokens=max_tokens, stop_at=stop_at)


+@regex.register(TransformersVisionModel)
+def regex_transformers_vision(
+    model: TransformersVisionModel,
+    prompt: str,
+    regex: str,
+    media=None,
+    max_tokens: Optional[int] = None,
+    stop_at: Optional[Union[List[str], str]] = None,
+    **kwargs
+):
+    generator = outlines.generate.regex(model.model_obj, regex_str=regex)
+    return generator(prompt, media=media, max_tokens=max_tokens, stop_at=stop_at)
+
+
 @regex.register(OllamaLLM)
 def regex_ollama(*_, **__) -> str:
     """Helper function to work with Ollama models,
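Downstream, the VQA ingredient calls this with a batch-size-aware pattern (see `regex(len(values))` in vqa/main.py below). A hypothetical factory showing the shape such a pattern might take; the `;` separator is a stand-in for whatever CONST.DEFAULT_ANS_SEP actually is:

    def yes_no_batch_regex(n: int, sep: str = ";") -> str:
        # Force exactly `n` yes/no answers joined by `sep`; matches the
        # Callable[[int], str] signature VQA expects for its `regex` argument.
        single = "(yes|no)"
        return single + (sep + single) * (n - 1)

    print(yes_no_batch_regex(3))  # (yes|no);(yes|no);(yes|no)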
blendsql/generate/text.py (15 changes: 14 additions & 1 deletion)

@@ -5,7 +5,7 @@
 import outlines

 from .._logger import logger
-from ..models import Model, OllamaLLM
+from ..models import Model, OllamaLLM, TransformersVisionModel


 @singledispatch
@@ -20,6 +20,19 @@ def text(
     return generator(prompt, max_tokens=max_tokens, stop_at=stop_at)


+@text.register(TransformersVisionModel)
+def text_transformers_vision(
+    model: TransformersVisionModel,
+    prompt: str,
+    media=None,
+    max_tokens: Optional[int] = None,
+    stop_at: Optional[Union[List[str], str]] = None,
+    **kwargs
+):
+    generator = outlines.generate.text(model.model_obj)
+    return generator(prompt, media=media, max_tokens=max_tokens, stop_at=stop_at)
+
+
 @text.register(OllamaLLM)
 def text_ollama(model: OllamaLLM, prompt, **kwargs) -> str:
     """Helper function to work with Ollama models,
blendsql/ingredients/__init__.py (2 changes: 1 addition & 1 deletion)

@@ -7,4 +7,4 @@
     IngredientException,
 )

-from .builtin import LLMQA, LLMJoin, LLMMap, LLMValidate, ImageCaption
+from .builtin import LLMQA, LLMJoin, LLMMap, LLMValidate, VQA
blendsql/ingredients/builtin/__init__.py (2 changes: 1 addition & 1 deletion)

@@ -1,5 +1,5 @@
 from .join.main import LLMJoin
 from .qa.main import LLMQA
 from .map.main import LLMMap
-from .vqa.main import ImageCaption
+from .vqa.main import VQA
 from .validate.main import LLMValidate
blendsql/ingredients/builtin/map/main.py (115 changes: 24 additions & 91 deletions)

@@ -1,15 +1,11 @@
-import logging
 from typing import Union, Iterable, Any, Dict, Optional, List, Callable, Tuple

-import json
-import pandas as pd
-from colorama import Fore
-from tqdm import tqdm
+from functools import partial

 from blendsql.utils import newline_dedent
+from blendsql.ingredients.utils import batch_run_map
 from blendsql._logger import logger
 from blendsql.models import Model, LocalModel, RemoteModel, OpenaiLLM
-from ast import literal_eval
 from blendsql import _constants as CONST
 from blendsql.ingredients.ingredient import MapIngredient
 from blendsql._program import Program
@@ -175,91 +171,28 @@ def run(
             logger.debug(f"Tablename {tablename} not in given table_to_title!")
         else:
             table_title = table_to_title[tablename]
-        split_results: List[Union[str, None]] = []
-        # Only use tqdm if we're in debug mode
-        context_manager: Iterable = (
-            tqdm(
-                range(0, len(values), CONST.MAP_BATCH_SIZE),
-                total=len(values) // CONST.MAP_BATCH_SIZE,
-                desc=f"Making calls to Model with batch_size {CONST.MAP_BATCH_SIZE}",
-                bar_format="{l_bar}%s{bar}%s{r_bar}" % (Fore.CYAN, Fore.RESET),
-            )
-            if logger.level <= logging.DEBUG
-            else range(0, len(values), CONST.MAP_BATCH_SIZE)
-        )
+        include_tf_disclaimer = False
+        if output_type == "boolean":
+            include_tf_disclaimer = True
+        elif isinstance(model, OpenaiLLM):
+            include_tf_disclaimer = True
+        pred_func = partial(
+            model.predict,
+            program=MapProgram,
+            question=question,
+            sep=CONST.DEFAULT_ANS_SEP,
+            example_outputs=example_outputs,
+            output_type=output_type,
+            include_tf_disclaimer=include_tf_disclaimer,
+            table_title=table_title,
+            regex=regex,
+            **kwargs,
+        )
-        for i in context_manager:
-            answer_length = len(values[i : i + CONST.MAP_BATCH_SIZE])
-            max_tokens = answer_length * 15
-            include_tf_disclaimer = False
-
-            if output_type == "boolean":
-                include_tf_disclaimer = True
-            elif isinstance(model, OpenaiLLM):
-                include_tf_disclaimer = True
-
-            result = model.predict(
-                program=MapProgram,
-                question=question,
-                sep=CONST.DEFAULT_ANS_SEP,
-                values=values[i : i + CONST.MAP_BATCH_SIZE],
-                example_outputs=example_outputs,
-                output_type=output_type,
-                include_tf_disclaimer=include_tf_disclaimer,
-                table_title=table_title,
-                regex=regex,
-                max_tokens=max_tokens,
-                **kwargs,
-            )
-            # Post-process language model response
-            _r = [
-                i.strip()
-                for i in result.strip(CONST.DEFAULT_ANS_SEP).split(
-                    CONST.DEFAULT_ANS_SEP
-                )
-            ]
-            # Try to map to booleans and `None`
-            _r = [
-                {
-                    "t": True,
-                    "f": False,
-                    "true": True,
-                    "false": False,
-                    "y": True,
-                    "n": False,
-                    "yes": True,
-                    "no": False,
-                    CONST.DEFAULT_NAN_ANS: None,
-                }.get(i.lower(), i)
-                for i in _r
-            ]
-            expected_len = len(values[i : i + CONST.MAP_BATCH_SIZE])
-            if len(_r) != expected_len:
-                logger.debug(
-                    Fore.YELLOW
-                    + f"Mismatch between length of values and answers!\nvalues:{expected_len}, answers:{len(_r)}"
-                    + Fore.RESET
-                )
-                logger.debug(_r)
-            # Cut off, in case we over-predicted
-            _r = _r[:expected_len]
-            # Add, in case we under-predicted
-            while len(_r) < expected_len:
-                _r.append(None)
-            split_results.extend(_r)
-        for idx, i in enumerate(split_results):
-            if i is None:
-                continue
-            if isinstance(i, str):
-                i = i.replace(",", "")
-            try:
-                split_results[idx] = literal_eval(i)
-                assert isinstance(i, (float, int, str))
-            except (ValueError, SyntaxError, AssertionError):
-                continue
-        logger.debug(
-            Fore.YELLOW
-            + f"Finished LLMMap with values:\n{json.dumps(dict(zip(values[:10], split_results[:10])), indent=4)}"
-            + Fore.RESET
-        )
+        split_results: List[Any] = batch_run_map(
+            pred_func,
+            values=values,
+            batch_size=CONST.MAP_BATCH_SIZE,
+            sep=CONST.DEFAULT_ANS_SEP,
+            nan_answer=CONST.DEFAULT_NAN_ANS,
+        )
         return split_results
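The `batch_run_map` helper is introduced by this commit but its body is not shown in the diff. Judging from the inline logic it replaces, a plausible minimal sketch (the name and keyword arguments are real; the body is an assumption reconstructed from the deleted loop):

    from typing import Any, Callable, List, Union

    def batch_run_map(
        pred_func: Callable[..., str],
        values: List[Any],
        batch_size: int,
        sep: str,
        nan_answer: str,
    ) -> List[Union[str, bool, None]]:
        # Run pred_func over `values` in batches, split each response on `sep`,
        # coerce t/f/yes/no-style answers to booleans, then truncate or pad each
        # batch to the expected length, mirroring the deleted inline loop.
        to_bool = {
            "t": True, "f": False, "true": True, "false": False,
            "y": True, "n": False, "yes": True, "no": False,
            nan_answer: None,
        }
        results: List[Union[str, bool, None]] = []
        for i in range(0, len(values), batch_size):
            batch = values[i : i + batch_size]
            raw = pred_func(values=batch, max_tokens=len(batch) * 15)
            answers = [a.strip() for a in raw.strip(sep).split(sep)]
            answers = [to_bool.get(a.lower(), a) for a in answers]
            answers = answers[: len(batch)]                   # cut off over-predictions
            answers += [None] * (len(batch) - len(answers))   # pad under-predictions
            results.extend(answers)
        return results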
blendsql/ingredients/builtin/vqa/__init__.py (2 changes: 1 addition & 1 deletion)

@@ -1 +1 @@
-from .main import ImageCaption
+from .main import VQA
blendsql/ingredients/builtin/vqa/main.py (87 changes: 67 additions & 20 deletions)

@@ -1,33 +1,66 @@
-from typing import List, Tuple
+from typing import List, Optional, Callable, Any
 from io import BytesIO
 from PIL import Image
+from functools import partial

-from blendsql.models import Model
+from blendsql.models import TransformersVisionModel, LocalModel
+from blendsql.ingredients.utils import batch_run_map
 from blendsql._program import Program
 from blendsql.ingredients.ingredient import MapIngredient
 from blendsql._exceptions import IngredientException
+from blendsql import _constants as CONST
+from blendsql import generate


-class ImageCaptionProgram(Program):
+class VQAProgram(Program):
     def __call__(
-        self, model: Model, img_bytes: List[bytes], **kwargs
-    ) -> Tuple[List[str], str]:
-        model_output = model.model_obj(
-            images=[Image.open(BytesIO(value)) for value in img_bytes],
-            # prompt=prompt,
-            generate_kwargs={"max_new_tokens": 200},
-        )
-        return ([output[0]["generated_text"].strip() for output in model_output], "")
+        self,
+        model: TransformersVisionModel,
+        question: str,
+        values: List[bytes],
+        sep: str,
+        max_tokens: Optional[int] = None,
+        regex: Optional[Callable[[int], str]] = None,
+        **kwargs,
+    ):
+        content = [
+            {"type": "text", "text": question},
+        ] + [{"type": "image"} for _ in range(len(values))]
+        if len(values) > 1:
+            content.insert(
+                0,
+                {
+                    "type": "text",
+                    "text": f"Answer the below question for each provided image, with individual answers separated by {sep}",
+                },
+            )
+        conversation = [
+            {"role": "user", "content": content},
+        ]
+        prompt = model.processor.apply_chat_template(conversation)
+        images: List[Image.Image] = [Image.open(BytesIO(value)) for value in values]
+        if isinstance(model, LocalModel) and regex is not None:
+            response = generate.regex(
+                model, prompt=prompt, media=images, regex=regex(len(values))
+            )
+        else:
+            response = generate.text(
+                model, prompt=prompt, media=images, max_tokens=max_tokens
+            )
+        return (response, prompt)
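The `content` list above follows the Hugging Face multimodal chat-template convention: each `{"type": "image"}` placeholder is paired positionally with an entry in `media`, and `apply_chat_template` renders the conversation into the model-specific prompt string. For two images, the assembled message looks roughly like this (question text is hypothetical):

    conversation = [
        {
            "role": "user",
            "content": [
                {"type": "text", "text": "Answer the below question for each provided image, ..."},
                {"type": "text", "text": "Is there a dog in this image?"},
                {"type": "image"},
                {"type": "image"},
            ],
        }
    ]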

-class ImageCaption(MapIngredient):
-    DESCRIPTION = """
-    If we need to generate a caption for an image stored in the database, we can use the scalar function to map to a new column:
-    `{{ImageCaption('table::column')}}`
-    """
-
-    def run(self, model: Model, values: List[bytes], **kwargs):
-        """Generates a caption for all byte images passed to it."""
+class VQA(MapIngredient):
+    def run(
+        self,
+        model: TransformersVisionModel,
+        question: str,
+        values: List[bytes],
+        regex: Optional[Callable[[int], str]] = None,
+        **kwargs,
+    ):
         if model is None:
             raise IngredientException(
                 "VQA requires a `Model` object, but nothing was passed!\nMost likely you forgot to set the `default_model` argument in `blend()`"
@@ -36,4 +69,18 @@ def run(self, model: Model, values: List[bytes], **kwargs):
             raise IngredientException(
                 f"All values must be 'byte' type for VQA!"
             )
-        return model.predict(program=ImageCaptionProgram, img_bytes=values, **kwargs)
+        pred_func = partial(
+            model.predict,
+            program=VQAProgram,
+            question=question,
+            regex=regex,
+            **kwargs,
+        )
+        split_results: List[Any] = batch_run_map(
+            pred_func,
+            values=values,
+            batch_size=CONST.MAP_BATCH_SIZE,
+            sep=CONST.DEFAULT_ANS_SEP,
+            nan_answer=CONST.DEFAULT_NAN_ANS,
+        )
+        return split_results