Fix hugging face prompt pipeline driver + docs (#796)
dylanholmes authored May 22, 2024
1 parent 9db5c8d commit b0606c7
Showing 4 changed files with 336 additions and 9 deletions.
20 changes: 14 additions & 6 deletions docs/griptape-framework/drivers/prompt-drivers.md
@@ -242,6 +242,15 @@ The [HuggingFaceHubPromptDriver](../../reference/griptape/drivers/prompt/hugging
- text2text-generation
- text-generation

!!! warning
    Not all models featured on the Hugging Face Hub are supported by this driver. Models that are not supported by
    [Hugging Face serverless inference](https://huggingface.co/docs/api-inference/en/index) will not work with this driver.
    Due to the limitations of Hugging Face serverless inference, only models smaller than 10GB are supported.

!!! info
    The `prompt_stack_to_string_converter` function is intended to convert a `PromptStack` to model-specific input. You
    should consult the model's documentation to determine the correct format.
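
To illustrate what such a converter does, here is a minimal sketch. Note the assumptions: `StackInput` and the `PromptStack` class below are simplified stand-ins that mimic griptape's interface (the real `PromptStack` exposes `.inputs`, each with `.role` and `.content`), and the template shown is the Zephyr-style chat format used by models such as TinyLlama-1.1B-Chat; other models expect different formats.

```python
from dataclasses import dataclass, field

# Simplified stand-ins for griptape's PromptStack types (assumption:
# the real class exposes `.inputs`, each with `.role` and `.content`).
@dataclass
class StackInput:
    role: str
    content: str

@dataclass
class PromptStack:
    inputs: list = field(default_factory=list)

def prompt_stack_to_string_converter(prompt_stack: PromptStack) -> str:
    # Render each message in the Zephyr-style chat template, then append
    # the assistant tag to cue the model to start generating.
    lines = [f"<|{i.role}|>\n{i.content}</s>" for i in prompt_stack.inputs]
    lines.append("<|assistant|>\n")
    return "\n".join(lines)

stack = PromptStack(inputs=[
    StackInput(role="system", content="You are a helpful assistant."),
    StackInput(role="user", content="Hello!"),
])
print(prompt_stack_to_string_converter(stack))
```

Check the model card for the exact chat template before writing a converter for a different model.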

Let's recreate the [Falcon-7B-Instruct](https://huggingface.co/tiiuae/falcon-7b-instruct) example using Griptape:

```python
@@ -319,9 +328,9 @@ agent.run("Write the code for a snake game.")
### Hugging Face Pipeline

!!! info
-    This driver requires the `drivers-prompt-huggingface` [extra](../index.md#extras).
+    This driver requires the `drivers-prompt-huggingface-pipeline` [extra](../index.md#extras).

-The [HuggingFaceHubPromptDriver](../../reference/griptape/drivers/prompt/huggingface_pipeline_prompt_driver.md) uses [Hugging Face Pipelines](https://huggingface.co/docs/transformers/main_classes/pipelines) for inference locally. It supports models with the following tasks:
+The [HuggingFacePipelinePromptDriver](../../reference/griptape/drivers/prompt/huggingface_pipeline_prompt_driver.md) uses [Hugging Face Pipelines](https://huggingface.co/docs/transformers/main_classes/pipelines) for inference locally. It supports models with the following tasks:

- text2text-generation
- text-generation
@@ -332,7 +341,7 @@ The [HuggingFaceHubPromptDriver](../../reference/griptape/drivers/prompt/hugging
```python
import os
from griptape.structures import Agent
-from griptape.drivers import HuggingFaceHubPromptDriver
+from griptape.drivers import HuggingFacePipelinePromptDriver
from griptape.rules import Rule, Ruleset
from griptape.utils import PromptStack
from griptape.config import StructureConfig
@@ -357,9 +366,8 @@ def prompt_stack_to_string_converter(prompt_stack: PromptStack) -> str:

agent = Agent(
config=StructureConfig(
-prompt_driver=HuggingFaceHubPromptDriver(
-model="tiiuae/falcon-7b-instruct",
-api_token=os.environ["HUGGINGFACE_HUB_ACCESS_TOKEN"],
+prompt_driver=HuggingFacePipelinePromptDriver(
+model="TinyLlama/TinyLlama-1.1B-Chat-v0.6",
prompt_stack_to_string=prompt_stack_to_string_converter,
)
),
@@ -21,6 +21,7 @@ class HuggingFacePipelinePromptDriver(BasePromptDriver):
SUPPORTED_TASKS = ["text2text-generation", "text-generation"]
DEFAULT_PARAMS = {"return_full_text": False, "num_return_sequences": 1}

+max_tokens: int = field(default=250, kw_only=True, metadata={"serializable": True})
model: str = field(kw_only=True, metadata={"serializable": True})
params: dict = field(factory=dict, kw_only=True, metadata={"serializable": True})
tokenizer: HuggingFaceTokenizer = field(