Skip to content

Commit

Permalink
Updated HuggingFacePipelinePromptDriver to use chat features of `tr…
Browse files Browse the repository at this point in the history
…ansformers.TextGenerationPipeline`. (#832)
  • Loading branch information
collindutter authored Jun 5, 2024
1 parent a7f83b3 commit 7025db8
Show file tree
Hide file tree
Showing 4 changed files with 47 additions and 52 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

### Changed
- Updated `HuggingFaceHubPromptDriver` to use `transformers`'s `apply_chat_template`.
- Updated `HuggingFacePipelinePromptDriver` to use chat features of `transformers.TextGenerationPipeline`.

## [0.26.0] - 2024-06-04

Expand Down
37 changes: 6 additions & 31 deletions docs/griptape-framework/drivers/prompt-drivers.md
Original file line number Diff line number Diff line change
Expand Up @@ -305,62 +305,37 @@ agent.run("Write the code for a snake game.")
!!! info
This driver requires the `drivers-prompt-huggingface-pipeline` [extra](../index.md#extras).

The [HuggingFacePipelinePromptDriver](../../reference/griptape/drivers/prompt/huggingface_pipeline_prompt_driver.md) uses [Hugging Face Pipelines](https://huggingface.co/docs/transformers/main_classes/pipelines) for inference locally. It supports models with the following tasks:

- text2text-generation
- text-generation
The [HuggingFacePipelinePromptDriver](../../reference/griptape/drivers/prompt/huggingface_pipeline_prompt_driver.md) uses [Hugging Face Pipelines](https://huggingface.co/docs/transformers/main_classes/pipelines) for inference locally.

!!! warning
Running a model locally can be a computationally expensive process.

```python
import os
from griptape.structures import Agent
from griptape.drivers import HuggingFacePipelinePromptDriver
from griptape.rules import Rule, Ruleset
from griptape.utils import PromptStack
from griptape.config import StructureConfig


# Override the default Prompt Stack to string converter
# to format the prompt in a way that is easier for this model to understand.
def prompt_stack_to_string_converter(prompt_stack: PromptStack) -> str:
prompt_lines = []

for i in prompt_stack.inputs:
if i.is_user():
prompt_lines.append(f"User: {i.content}")
elif i.is_assistant():
prompt_lines.append(f"Girafatron: {i.content}")
else:
prompt_lines.append(f"Instructions: {i.content}")
prompt_lines.append("Girafatron:")

return "\n".join(prompt_lines)


agent = Agent(
config=StructureConfig(
prompt_driver=HuggingFacePipelinePromptDriver(
model="TinyLlama/TinyLlama-1.1B-Chat-v0.6",
prompt_stack_to_string=prompt_stack_to_string_converter,
model="TinyLlama/TinyLlama-1.1B-Chat-v1.0",
)
),
rulesets=[
Ruleset(
name="Girafatron",
name="Pirate",
rules=[
Rule(
value="You are Girafatron, a giraffe-obsessed robot. You are talking to a human. "
"Girafatron is obsessed with giraffes, the most glorious animal on the face of this Earth. "
"Giraftron believes all other animals are irrelevant when compared to the glorious majesty of the giraffe."
value="You are a pirate chatbot who always responds in pirate speak!"
)
]
],
)
],
)

agent.run("Hello Girafatron, what is your favorite animal?")
agent.run("How many helicopters can a human eat in one sitting?")
```

### Multi Model Prompt Drivers
Expand Down
41 changes: 25 additions & 16 deletions griptape/drivers/prompt/huggingface_pipeline_prompt_driver.py
Original file line number Diff line number Diff line change
@@ -1,26 +1,27 @@
from __future__ import annotations
from collections.abc import Iterator

from typing import TYPE_CHECKING
from attrs import Factory, define, field

from griptape.artifacts import TextArtifact
from griptape.drivers import BasePromptDriver
from griptape.tokenizers import HuggingFaceTokenizer
from griptape.utils import PromptStack, import_optional_dependency

if TYPE_CHECKING:
from transformers import TextGenerationPipeline


@define
class HuggingFacePipelinePromptDriver(BasePromptDriver):
"""
Attributes:
params: Custom model run parameters.
model: Hugging Face Hub model name.
tokenizer: Custom `HuggingFaceTokenizer`.
"""

SUPPORTED_TASKS = ["text2text-generation", "text-generation"]
DEFAULT_PARAMS = {"return_full_text": False, "num_return_sequences": 1}

max_tokens: int = field(default=250, kw_only=True, metadata={"serializable": True})
model: str = field(kw_only=True, metadata={"serializable": True})
params: dict = field(factory=dict, kw_only=True, metadata={"serializable": True})
Expand All @@ -34,28 +35,36 @@ class HuggingFacePipelinePromptDriver(BasePromptDriver):
),
kw_only=True,
)
pipe: TextGenerationPipeline = field(
default=Factory(
lambda self: import_optional_dependency("transformers").pipeline(
"text-generation", model=self.model, max_new_tokens=self.max_tokens, tokenizer=self.tokenizer.tokenizer
),
takes_self=True,
)
)

def try_run(self, prompt_stack: PromptStack) -> TextArtifact:
prompt = self.prompt_stack_to_string(prompt_stack)
pipeline = import_optional_dependency("transformers").pipeline
messages = [{"role": input.role, "content": input.content} for input in prompt_stack.inputs]

generator = pipeline(
result = self.pipe(
messages,
max_new_tokens=self.max_tokens,
tokenizer=self.tokenizer.tokenizer,
model=self.model,
max_new_tokens=self.tokenizer.count_output_tokens_left(prompt),
stop_strings=self.tokenizer.stop_sequences,
temperature=self.temperature,
do_sample=True,
)

if generator.task in self.SUPPORTED_TASKS:
extra_params = {"pad_token_id": self.tokenizer.tokenizer.eos_token_id}

response = generator(prompt, **(self.DEFAULT_PARAMS | extra_params | self.params))
if isinstance(result, list):
if len(result) == 1:
generated_text = result[0]["generated_text"][-1]["content"]

if len(response) == 1:
return TextArtifact(value=response[0]["generated_text"].strip())
return TextArtifact(value=generated_text)
else:
raise Exception("completion with more than one choice is not supported yet")
else:
raise Exception(f"only models with the following tasks are supported: {self.SUPPORTED_TASKS}")
raise Exception("invalid output format")

def try_stream(self, prompt_stack: PromptStack) -> Iterator[TextArtifact]:
raise NotImplementedError("streaming is not supported")
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ def mock_pipeline(self, mocker):
def mock_generator(self, mock_pipeline):
mock_generator = mock_pipeline.return_value
mock_generator.task = "text-generation"
mock_generator.return_value = [{"generated_text": "model-output"}]
mock_generator.return_value = [{"generated_text": [{"content": "model-output"}]}]
return mock_generator

@pytest.fixture(autouse=True)
Expand Down Expand Up @@ -44,6 +44,16 @@ def test_try_run(self, prompt_stack):
# Then
assert text_artifact.value == "model-output"

def test_try_stream(self, prompt_stack):
# Given
driver = HuggingFacePipelinePromptDriver(model="foo", max_tokens=42)

# When
with pytest.raises(Exception) as e:
driver.try_stream(prompt_stack)

assert e.value.args[0] == "streaming is not supported"

@pytest.mark.parametrize("choices", [[], [1, 2]])
def test_try_run_throws_when_multiple_choices_returned(self, choices, mock_generator, prompt_stack):
# Given
Expand All @@ -55,16 +65,16 @@ def test_try_run_throws_when_multiple_choices_returned(self, choices, mock_gener
driver.try_run(prompt_stack)

# Then
e.value.args[0] == "completion with more than one choice is not supported yet"
assert e.value.args[0] == "completion with more than one choice is not supported yet"

def test_try_run_throws_when_unsupported_task_returned(self, prompt_stack, mock_generator):
def test_try_run_throws_when_non_list(self, mock_generator, prompt_stack):
# Given
driver = HuggingFacePipelinePromptDriver(model="foo", max_tokens=42)
mock_generator.task = "obviously-an-unsupported-task"
mock_generator.return_value = {}

# When
with pytest.raises(Exception) as e:
driver.try_run(prompt_stack)

# Then
assert e.value.args[0].startswith("only models with the following tasks are supported: ")
assert e.value.args[0] == "invalid output format"

0 comments on commit 7025db8

Please sign in to comment.