Add TextGenerationWithImage task (#1066)
Co-authored-by: Gabriel Martín Blázquez <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
3 people authored Dec 9, 2024
1 parent 55d9e5d commit 63c75c5
Showing 20 changed files with 629 additions and 66 deletions.
116 changes: 116 additions & 0 deletions docs/sections/pipeline_samples/examples/text_generation_with_image.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,116 @@
---
hide: toc
---

# Text generation with images in `distilabel`

Answer questions about images using `distilabel`.

Image-text-to-text models take an image and a text prompt as input and generate text. In this example we will use [`InferenceEndpointsLLM`](https://distilabel.argilla.io/dev/components-gallery/llms/inferenceendpointsllm/) with [meta-llama/Llama-3.2-11B-Vision-Instruct](https://huggingface.co/meta-llama/Llama-3.2-11B-Vision-Instruct), and [`OpenAILLM`](https://distilabel.argilla.io/dev/components-gallery/llms/openaillm/) with `gpt-4o-mini`, to ask a simple question about an image and showcase how the [`TextGenerationWithImage`](https://distilabel.argilla.io/dev/components-gallery/tasks/textgenerationwithimage/) task can be used in a pipeline.

=== "Inference Endpoints - meta-llama/Llama-3.2-11B-Vision-Instruct"

```python
from distilabel.models.llms import InferenceEndpointsLLM
from distilabel.pipeline import Pipeline
from distilabel.steps.tasks.text_generation_with_image import TextGenerationWithImage
from distilabel.steps import LoadDataFromDicts


with Pipeline(name="vision_generation_pipeline") as pipeline:
loader = LoadDataFromDicts(
data=[
{
"instruction": "What’s in this image?",
"image": "https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg"
}
],
)

llm = InferenceEndpointsLLM(
model_id="meta-llama/Llama-3.2-11B-Vision-Instruct",
)

vision = TextGenerationWithImage(
name="vision_gen",
llm=llm,
image_type="url" # (1)
)

loader >> vision
```

1. The *image_type* can be a URL pointing to the image, its base64 string representation, or a PIL image; take a look at [`TextGenerationWithImage`](https://distilabel.argilla.io/dev/components-gallery/tasks/textgenerationwithimage/) for more information.
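When the image is not publicly hosted, it can be passed as a base64 string instead of a URL. A minimal sketch of preparing such a row, assuming `image_type="base64"` as described in the task docs (the helper name is ours):

```python
import base64


# Hypothetical helper: encode raw image bytes as the base64 string that
# the "image" column would carry when using image_type="base64".
def image_bytes_to_base64(image_bytes: bytes) -> str:
    return base64.b64encode(image_bytes).decode("utf-8")


row = {
    "instruction": "What's in this image?",
    # In a real pipeline these bytes would come from reading an image file.
    "image": image_bytes_to_base64(b"fake-image-bytes"),
}
```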

Image:

![Image](https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg)

Question:

> What’s in this image?

Response:

> This image depicts a wooden boardwalk weaving its way through a lush meadow, flanked by vibrant green grass that stretches towards the horizon under a calm and inviting sky. The boardwalk runs straight ahead, away from the viewer, forming a clear pathway through the tall, lush green grass, crops or other plant types or an assortment of small trees and shrubs. This meadow is dotted with trees and shrubs, appearing to be healthy and green. The sky above is a beautiful blue with white clouds scattered throughout, adding a sense of tranquility to the scene. While this image appears to be of a natural landscape, because grass is...

=== "OpenAI - gpt-4o-mini"

```python
from distilabel.models.llms import OpenAILLM
from distilabel.pipeline import Pipeline
from distilabel.steps.tasks.text_generation_with_image import TextGenerationWithImage
from distilabel.steps import LoadDataFromDicts


with Pipeline(name="vision_generation_pipeline") as pipeline:
loader = LoadDataFromDicts(
data=[
{
"instruction": "What’s in this image?",
"image": "https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg"
}
],
)

llm = OpenAILLM(
model="gpt-4o-mini",
)

vision = TextGenerationWithImage(
name="vision_gen",
llm=llm,
image_type="url" # (1)
)

loader >> vision
```

1. The *image_type* can be a URL pointing to the image, its base64 string representation, or a PIL image; take a look at [`TextGenerationWithImage`](https://distilabel.argilla.io/dev/components-gallery/tasks/textgenerationwithimage/) for more information.

Image:

![Image](https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg)

Question:

> What’s in this image?

Response:

> The image depicts a scenic landscape featuring a wooden walkway or path that runs through a lush green marsh or field. The area is surrounded by tall grass and various shrubs, with trees likely visible in the background. The sky is blue with some wispy clouds, suggesting a beautiful day. Overall, it presents a peaceful natural setting, ideal for a stroll or nature observation.


The full pipeline can be run with the following script:

??? Note "Run the full pipeline"

```bash
python examples/text_generation_with_image.py
```

```python title="text_generation_with_image.py"
--8<-- "examples/text_generation_with_image.py"
```

A sample dataset can be seen at [plaguss/test-vision-generation-Llama-3.2-11B-Vision-Instruct](https://huggingface.co/datasets/plaguss/test-vision-generation-Llama-3.2-11B-Vision-Instruct).
7 changes: 7 additions & 0 deletions docs/sections/pipeline_samples/index.md
@@ -161,6 +161,13 @@ hide: toc

[:octicons-arrow-right-24: Example](examples/exam_questions.md)

- __Text generation with images in distilabel__

---

Ask questions about images using distilabel.

[:octicons-arrow-right-24: Example](examples/text_generation_with_image.md)

</div>

41 changes: 41 additions & 0 deletions examples/text_generation_with_image.py
@@ -0,0 +1,41 @@
# Copyright 2023-present, Argilla, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from distilabel.models.llms import InferenceEndpointsLLM
from distilabel.pipeline import Pipeline
from distilabel.steps import LoadDataFromDicts
from distilabel.steps.tasks.text_generation_with_image import TextGenerationWithImage

with Pipeline(name="vision_generation_pipeline") as pipeline:
loader = LoadDataFromDicts(
data=[
{
"instruction": "What’s in this image?",
"image": "https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg",
}
],
)

llm = InferenceEndpointsLLM(
model_id="meta-llama/Llama-3.2-11B-Vision-Instruct",
)

vision = TextGenerationWithImage(name="vision_gen", llm=llm, image_type="url")

loader >> vision


if __name__ == "__main__":
distiset = pipeline.run(use_cache=False)
distiset.push_to_hub("plaguss/test-vision-generation-Llama-3.2-11B-Vision-Instruct")
3 changes: 3 additions & 0 deletions mkdocs.yml
@@ -94,6 +94,8 @@ theme:
watch:
- src/distilabel

strict: true

# Extensions
markdown_extensions:
- attr_list
@@ -220,6 +222,7 @@ nav:
- Structured generation with instructor: "sections/pipeline_samples/examples/mistralai_with_instructor.md"
- Create a social network with FinePersonas: "sections/pipeline_samples/examples/fine_personas_social_network.md"
- Create questions and answers for an exam: "sections/pipeline_samples/examples/exam_questions.md"
- Text generation with images in distilabel: "sections/pipeline_samples/examples/text_generation_with_image.md"
- API Reference:
- Step:
- "api/step/index.md"
6 changes: 6 additions & 0 deletions src/distilabel/models/llms/openai.py
@@ -293,6 +293,12 @@ async def agenerate( # type: ignore
"top_p": top_p,
"stop": stop,
}
# Check if this is a vision generation task; in that case "stop" cannot be used,
# as it raises an error in the API.
if isinstance(
[row for row in input if row["role"] == "user"][0]["content"], list
):
kwargs.pop("stop")

if response_format is not None:
kwargs["response_format"] = response_format
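The check added above keys off the shape of the message content: OpenAI-style multimodal messages carry their `"content"` as a list of parts (text plus `image_url`), while plain text messages carry a single string. A standalone sketch of the same detection (function and variable names are ours):

```python
def is_vision_input(messages: list) -> bool:
    # Multimodal (vision) messages store "content" as a list of parts;
    # plain text messages store a single string.
    first_user = [m for m in messages if m["role"] == "user"][0]
    return isinstance(first_user["content"], list)


kwargs = {"temperature": 0.7, "stop": ["\n\n"]}
vision_messages = [
    {
        "role": "user",
        "content": [
            {"type": "text", "text": "What's in this image?"},
            {"type": "image_url", "image_url": {"url": "https://example.com/img.jpg"}},
        ],
    }
]
if is_vision_input(vision_messages):
    kwargs.pop("stop")  # the vision endpoint rejects "stop"
```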
14 changes: 12 additions & 2 deletions src/distilabel/models/llms/vertexai.py
@@ -12,9 +12,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import TYPE_CHECKING, Any, Dict, List, Optional
from typing import TYPE_CHECKING, Any, Dict, List, Literal, Optional

from pydantic import PrivateAttr, validate_call
from typing_extensions import TypedDict

from distilabel.models.llms.base import AsyncLLM
from distilabel.models.llms.typing import GenerateOutput
@@ -27,6 +28,15 @@
from distilabel.llms.typing import LLMStatistics


class VertexChatItem(TypedDict):
role: Literal["user", "model"]
content: str


VertexChatType = List[VertexChatItem]
"""VertexChatType is a type alias for a `list` of `dict`s following the VertexAI conversational format."""


class VertexAILLM(AsyncLLM):
"""VertexAI LLM implementation running the async API clients for Gemini.
@@ -121,7 +131,7 @@ def _chattype_to_content(self, input: "StandardInput") -> List["Content"]:
@validate_call
async def agenerate( # type: ignore
self,
input: StandardInput,
input: VertexChatType,
temperature: Optional[float] = None,
top_p: Optional[float] = None,
top_k: Optional[int] = None,
7 changes: 4 additions & 3 deletions src/distilabel/pipeline/base.py
@@ -95,6 +95,9 @@ class _CacheLocation(TypedDict):
stages_file: Path


LoadStages = tuple[list[list[str]], list[list[str]]]


class _GlobalPipelineManager:
"""Class to manage the global pipeline instance that will be used by the steps when
created within a pipeline context.
@@ -445,9 +448,7 @@ def dry_run(
self._dry_run = False
return distiset

def get_load_stages(
self, load_groups: Optional["LoadGroups"] = None
) -> Tuple[List[List[str]], List[List[str]]]:
def get_load_stages(self, load_groups: Optional["LoadGroups"] = None) -> LoadStages:
"""Convenient method to get the load stages of a pipeline.
Args:
2 changes: 2 additions & 0 deletions src/distilabel/steps/tasks/__init__.py
@@ -54,6 +54,7 @@
from distilabel.steps.tasks.structured_generation import StructuredGeneration
from distilabel.steps.tasks.text_classification import TextClassification
from distilabel.steps.tasks.text_generation import ChatGeneration, TextGeneration
from distilabel.steps.tasks.text_generation_with_image import TextGenerationWithImage
from distilabel.steps.tasks.typing import ChatItem, ChatType
from distilabel.steps.tasks.ultrafeedback import UltraFeedback
from distilabel.steps.tasks.urial import URIAL
@@ -101,4 +102,5 @@
"CLAIR",
"UltraFeedback",
"URIAL",
"TextGenerationWithImage",
]
8 changes: 7 additions & 1 deletion src/distilabel/steps/tasks/math_shepherd/completer.py
@@ -528,7 +528,13 @@ def _auto_label(

return inputs

def _add_metadata(self, input, statistics, raw_output, raw_input):
def _add_metadata(
self,
input: dict[str, Any],
statistics: list["LLMStatistics"],
raw_output: Union[str, None],
raw_input: Union[list[dict[str, Any]], None],
) -> dict[str, Any]:
"""Adds the `distilabel_metadata` to the input.
This method comes for free in the general Tasks, but as we have reimplemented the `process`,
10 changes: 3 additions & 7 deletions src/distilabel/steps/tasks/math_shepherd/utils.py
@@ -53,12 +53,12 @@ class FormatPRM(Step):
correct steps.
Attributes:
format (Literal["math-shepherd", "trl"]): The format to use for the PRM model.
format: The format to use for the PRM model.
"math-shepherd" corresponds to the original paper, while "trl" is a format
prepared to train the model using TRL.
step_token (str): String that serves as a unique token denoting the position
step_token: String that serves as a unique token denoting the position
for predicting the step score.
tags (list[str]): List of tags that represent the correct and incorrect steps.
tags: List of tags that represent the correct and incorrect steps.
This only needs to be informed if it's different than the default in
`MathShepherdCompleter`.
@@ -110,10 +110,6 @@ class FormatPRM(Step):
)
)
result = next(formatter.process(result))
# result[0]["input"]
# "Janet’s ducks lay 16 eggs per day. She eats three for breakfast every morning and bakes muffins for her friends every day with four. She sells the remainder at the farmers' market daily for $2 per fresh duck egg. How much in dollars does she make every day at the farmers' market? Step 1: Determine the amount of blue fiber needed: 2 bolts of blue fiber are required. ки\nStep 2: Calculate the amount of white fiber needed: Since it's half that much, we can divide 2 by 2: 2 / 2 = <<2/2=1>>1 bolt of white fiber. ки\nStep 3: Add the amount of blue and white fiber: 2 (blue) + 1 (white) = <<2+1=3>>3 bolts of fiber in total. The answer is: 3 ки"
# result[0]["label"]
# "Janet’s ducks lay 16 eggs per day. She eats three for breakfast every morning and bakes muffins for her friends every day with four. She sells the remainder at the farmers' market daily for $2 per fresh duck egg. How much in dollars does she make every day at the farmers' market? Step 1: Determine the amount of blue fiber needed: 2 bolts of blue fiber are required. +\nStep 2: Calculate the amount of white fiber needed: Since it's half that much, we can divide 2 by 2: 2 / 2 = <<2/2=1>>1 bolt of white fiber. +\nStep 3: Add the amount of blue and white fiber: 2 (blue) + 1 (white) = <<2+1=3>>3 bolts of fiber in total. The answer is: 3 +"
```
Prepare your data to train a PRM model with the TRL format:
19 changes: 1 addition & 18 deletions src/distilabel/steps/tasks/text_generation.py
@@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import re
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union

from jinja2 import Template
@@ -21,6 +20,7 @@
from distilabel.errors import DistilabelUserError
from distilabel.steps.tasks.base import Task
from distilabel.utils.chat import is_openai_format
from distilabel.utils.template import check_column_in_template

if TYPE_CHECKING:
from distilabel.steps.tasks.typing import ChatType
@@ -218,23 +218,6 @@ def model_post_init(self, __context: Any) -> None:
def load(self) -> None:
super().load()

def check_column_in_template(column, template):
pattern = (
r"(?:{%.*?\b"
+ re.escape(column)
+ r"\b.*?%}|{{\s*"
+ re.escape(column)
+ r"\s*}})"
)
if not re.search(pattern, template):
raise DistilabelUserError(
(
f"You required column name '{column}', but is not present in the template, "
"ensure the 'columns' match with the 'template' to avoid errors."
),
page="components-gallery/tasks/textgeneration/",
)

for column in self.columns:
check_column_in_template(column, self.template)
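The helper removed here now lives in `distilabel.utils.template`. Its core is a regex that accepts a column appearing either as a `{{ column }}` expression or inside a `{% ... %}` statement in the Jinja2 template. A standalone sketch of that check, returning a boolean rather than raising:

```python
import re


def column_in_template(column: str, template: str) -> bool:
    # Matches either a {% ... column ... %} statement or a {{ column }} expression.
    pattern = (
        r"(?:{%.*?\b"
        + re.escape(column)
        + r"\b.*?%}|{{\s*"
        + re.escape(column)
        + r"\s*}})"
    )
    return re.search(pattern, template) is not None


template = "Answer based on {{ context }}:\n{% for doc in documents %}{{ doc }}{% endfor %}"
column_in_template("context", template)    # True: appears as an expression
column_in_template("documents", template)  # True: appears inside a statement
column_in_template("question", template)   # False: missing from the template
```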

