Skip to content

Commit

Permalink
Add example of custom text generation step in quickstart (#984)
Browse files Browse the repository at this point in the history
  • Loading branch information
plaguss authored Sep 30, 2024
1 parent 3244c05 commit 3fd680c
Showing 1 changed file with 13 additions and 9 deletions.
22 changes: 13 additions & 9 deletions docs/sections/getting_started/quickstart.md
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ pip install distilabel[hf-inference-endpoints] --upgrade

## Define a pipeline

In this guide we will walk you through the process of creating a simple pipeline that uses the [`InferenceEndpointsLLM`][distilabel.llms.InferenceEndpointsLLM] class to generate text. The [`Pipeline`][distilabel.pipeline.Pipeline] will load a dataset that contains a column named `prompt` from the Hugging Face Hub via the step [`LoadDataFromHub`][distilabel.steps.LoadDataFromHub] and then use the [`InferenceEndpointsLLM`][distilabel.llms.InferenceEndpointsLLM] class to generate text based on the dataset using the [`TextGeneration`][distilabel.steps.tasks.TextGeneration] task.
In this guide we will walk you through the process of creating a simple pipeline that uses the [`InferenceEndpointsLLM`][distilabel.llms.InferenceEndpointsLLM] class to generate text. The [`Pipeline`][distilabel.pipeline.Pipeline] will load a dataset that contains a column named `prompt` from the Hugging Face Hub via the step [`LoadDataFromHub`][distilabel.steps.LoadDataFromHub] and then use the [`InferenceEndpointsLLM`][distilabel.llms.InferenceEndpointsLLM] class to generate text based on the dataset using the [`TextGeneration`](https://distilabel.argilla.io/dev/components-gallery/tasks/textgeneration/) task.

> You can check the available models in the [Hugging Face Model Hub](https://huggingface.co/models?pipeline_tag=text-generation&sort=trending) and filter by `Inference status`.
Expand All @@ -53,12 +53,14 @@ with Pipeline( # (1)
model_id="meta-llama/Meta-Llama-3.1-8B-Instruct",
tokenizer_id="meta-llama/Meta-Llama-3.1-8B-Instruct",
), # (5)
system_prompt="You are a creative AI Assistant writer.",
template="Follow the following instruction: {{ instruction }}" # (6)
)

load_dataset >> text_generation # (6)
load_dataset >> text_generation # (7)

if __name__ == "__main__":
distiset = pipeline.run( # (7)
distiset = pipeline.run( # (8)
parameters={
load_dataset.name: {
"repo_id": "distilabel-internal-testing/instruction-dataset-mini",
Expand All @@ -74,7 +76,7 @@ if __name__ == "__main__":
},
},
)
distiset.push_to_hub(repo_id="distilabel-example") # (8)
distiset.push_to_hub(repo_id="distilabel-example") # (9)
```

1. We define a [`Pipeline`][distilabel.pipeline.Pipeline] with the name `simple-text-generation-pipeline` and a description `A simple text generation pipeline`. Note that the `name` is mandatory and will be used to calculate the `cache` signature path, so changing the name will change the cache path and will be identified as a different pipeline.
Expand All @@ -83,12 +85,14 @@ if __name__ == "__main__":

3. We define a [`LoadDataFromHub`][distilabel.steps.LoadDataFromHub] step named `load_dataset` that will load a dataset from the Hugging Face Hub, as provided via runtime parameters in the `pipeline.run` method below, but it can also be defined within the class instance via the arg `repo_id=...`. This step will produce output batches with the rows from the dataset, and the column `prompt` will be mapped to the `instruction` field.

4. We define a [`TextGeneration`][distilabel.steps.tasks.TextGeneration] task named `text_generation` that will generate text based on the `instruction` field from the dataset. This task will use the [`InferenceEndpointsLLM`][distilabel.llms.InferenceEndpointsLLM] class with the model `Meta-Llama-3.1-8B-Instruct`.
4. We define a [`TextGeneration`](https://distilabel.argilla.io/dev/components-gallery/tasks/textgeneration/) task named `text_generation` that will generate text based on the `instruction` field from the dataset. This task will use the [`InferenceEndpointsLLM`][distilabel.llms.InferenceEndpointsLLM] class with the model `Meta-Llama-3.1-8B-Instruct`.

5. We define the [`InferenceEndpointsLLM`][distilabel.llms.InferenceEndpointsLLM] class with the model `Meta-Llama-3.1-8B-Instruct` that will be used by the [`TextGeneration`][distilabel.steps.tasks.TextGeneration] task. In this case, since the [`InferenceEndpointsLLM`][distilabel.llms.InferenceEndpointsLLM] is used, we assume that the `HF_TOKEN` environment variable is set.
5. We define the [`InferenceEndpointsLLM`][distilabel.llms.InferenceEndpointsLLM] class with the model `Meta-Llama-3.1-8B-Instruct` that will be used by the [`TextGeneration`](https://distilabel.argilla.io/dev/components-gallery/tasks/textgeneration/) task. In this case, since the [`InferenceEndpointsLLM`][distilabel.llms.InferenceEndpointsLLM] is used, we assume that the `HF_TOKEN` environment variable is set.

6. We connect the `load_dataset` step to the `text_generation` task using the `rshift` operator, meaning that the output from the `load_dataset` step will be used as input for the `text_generation` task.
6. Both `system_prompt` and `template` are optional fields. The `template` must be informed as a string following the [Jinja2](https://jinja.palletsprojects.com/en/3.1.x/templates/#synopsis) template format, and the fields that appear there ("instruction" in this case, which corresponds to the default) must be informed in the `columns` attribute. The component gallery for [`TextGeneration`](https://distilabel.argilla.io/dev/components-gallery/tasks/textgeneration/) has examples to get you started.

7. We run the pipeline with the parameters for the `load_dataset` and `text_generation` steps. The `load_dataset` step will use the repository `distilabel-internal-testing/instruction-dataset-mini` and the `test` split, and the `text_generation` task will use the `generation_kwargs` with the `temperature` set to `0.7` and the `max_new_tokens` set to `512`.
7. We connect the `load_dataset` step to the `text_generation` task using the `rshift` operator, meaning that the output from the `load_dataset` step will be used as input for the `text_generation` task.

8. Optionally, we can push the generated [`Distiset`][distilabel.distiset.Distiset] to the Hugging Face Hub repository `distilabel-example`. This will allow you to share the generated dataset with others and use it in other pipelines.
8. We run the pipeline with the parameters for the `load_dataset` and `text_generation` steps. The `load_dataset` step will use the repository `distilabel-internal-testing/instruction-dataset-mini` and the `test` split, and the `text_generation` task will use the `generation_kwargs` with the `temperature` set to `0.7` and the `max_new_tokens` set to `512`.

9. Optionally, we can push the generated [`Distiset`][distilabel.distiset.Distiset] to the Hugging Face Hub repository `distilabel-example`. This will allow you to share the generated dataset with others and use it in other pipelines.

0 comments on commit 3fd680c

Please sign in to comment.