Updated HuggingFacePipelinePromptDriver to use chat features of transformers.TextGenerationPipeline. #832

Merged · 2 commits · Jun 5, 2024

1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -8,6 +8,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

### Changed
- Updated `HuggingFaceHubPromptDriver` to use `transformers`'s `apply_chat_template`.
- Updated `HuggingFacePipelinePromptDriver` to use chat features of `transformers.TextGenerationPipeline`.

## [0.26.0] - 2024-06-04

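For reference, `apply_chat_template` (mentioned in the first changelog entry) is the `transformers` tokenizer method that renders a list of role/content messages into the model-specific prompt string. A minimal sketch of how it is typically called; the model name and messages below are illustrative, not taken from this PR:

```python
# Illustrative use of transformers' chat templating (model and messages are examples).
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("TinyLlama/TinyLlama-1.1B-Chat-v1.0")

messages = [
    {"role": "system", "content": "You are a pirate chatbot who always responds in pirate speak!"},
    {"role": "user", "content": "How many helicopters can a human eat in one sitting?"},
]

# Render the messages with the model's chat template and append the assistant
# generation prompt, producing a plain string ready to pass to a model.
prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
print(prompt)
```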
37 changes: 6 additions & 31 deletions docs/griptape-framework/drivers/prompt-drivers.md
@@ -305,62 +305,37 @@ agent.run("Write the code for a snake game.")
!!! info
This driver requires the `drivers-prompt-huggingface-pipeline` [extra](../index.md#extras).

The [HuggingFacePipelinePromptDriver](../../reference/griptape/drivers/prompt/huggingface_pipeline_prompt_driver.md) uses [Hugging Face Pipelines](https://huggingface.co/docs/transformers/main_classes/pipelines) for inference locally. It supports models with the following tasks:

- text2text-generation
- text-generation
The [HuggingFacePipelinePromptDriver](../../reference/griptape/drivers/prompt/huggingface_pipeline_prompt_driver.md) uses [Hugging Face Pipelines](https://huggingface.co/docs/transformers/main_classes/pipelines) for inference locally.

!!! warning
Running a model locally can be a computationally expensive process.

```python
import os
from griptape.structures import Agent
from griptape.drivers import HuggingFacePipelinePromptDriver
from griptape.rules import Rule, Ruleset
from griptape.utils import PromptStack
from griptape.config import StructureConfig


# Override the default Prompt Stack to string converter
# to format the prompt in a way that is easier for this model to understand.
def prompt_stack_to_string_converter(prompt_stack: PromptStack) -> str:
prompt_lines = []

for i in prompt_stack.inputs:
if i.is_user():
prompt_lines.append(f"User: {i.content}")
elif i.is_assistant():
prompt_lines.append(f"Girafatron: {i.content}")
else:
prompt_lines.append(f"Instructions: {i.content}")
prompt_lines.append("Girafatron:")

return "\n".join(prompt_lines)


agent = Agent(
config=StructureConfig(
prompt_driver=HuggingFacePipelinePromptDriver(
model="TinyLlama/TinyLlama-1.1B-Chat-v0.6",
prompt_stack_to_string=prompt_stack_to_string_converter,
model="TinyLlama/TinyLlama-1.1B-Chat-v1.0",
)
),
rulesets=[
Ruleset(
name="Girafatron",
name="Pirate",
rules=[
Rule(
value="You are Girafatron, a giraffe-obsessed robot. You are talking to a human. "
"Girafatron is obsessed with giraffes, the most glorious animal on the face of this Earth. "
"Giraftron believes all other animals are irrelevant when compared to the glorious majesty of the giraffe."
value="You are a pirate chatbot who always responds in pirate speak!"
)
]
],
)
],
)

agent.run("Hello Girafatron, what is your favorite animal?")
agent.run("How many helicopters can a human eat in one sitting?")
```
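
Under the hood, the updated driver leans on the text-generation pipeline's chat support: when handed a list of role/content message dicts, the pipeline applies the model's chat template itself and returns the conversation, including the new assistant turn, under `generated_text`. A rough sketch of that underlying call (parameter values here are illustrative, not the driver's exact configuration):

```python
# Sketch of the underlying transformers chat-pipeline call (parameters are illustrative).
from transformers import pipeline

pipe = pipeline("text-generation", model="TinyLlama/TinyLlama-1.1B-Chat-v1.0", max_new_tokens=250)

messages = [
    {"role": "system", "content": "You are a pirate chatbot who always responds in pirate speak!"},
    {"role": "user", "content": "How many helicopters can a human eat in one sitting?"},
]

# Passing message dicts makes the pipeline apply the chat template automatically;
# the result carries the full conversation, so the last entry is the new reply.
result = pipe(messages, do_sample=True, temperature=0.7)
print(result[0]["generated_text"][-1]["content"])
```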

### Multi Model Prompt Drivers
41 changes: 25 additions & 16 deletions griptape/drivers/prompt/huggingface_pipeline_prompt_driver.py
@@ -1,26 +1,27 @@
from __future__ import annotations
from collections.abc import Iterator

from typing import TYPE_CHECKING
from attrs import Factory, define, field

from griptape.artifacts import TextArtifact
from griptape.drivers import BasePromptDriver
from griptape.tokenizers import HuggingFaceTokenizer
from griptape.utils import PromptStack, import_optional_dependency

if TYPE_CHECKING:
from transformers import TextGenerationPipeline

[Codecov / codecov/patch check warning: added line griptape/drivers/prompt/huggingface_pipeline_prompt_driver.py#L13 was not covered by tests]


@define
class HuggingFacePipelinePromptDriver(BasePromptDriver):
"""
Attributes:
params: Custom model run parameters.
model: Hugging Face Hub model name.
tokenizer: Custom `HuggingFaceTokenizer`.

"""

SUPPORTED_TASKS = ["text2text-generation", "text-generation"]
DEFAULT_PARAMS = {"return_full_text": False, "num_return_sequences": 1}

max_tokens: int = field(default=250, kw_only=True, metadata={"serializable": True})
model: str = field(kw_only=True, metadata={"serializable": True})
params: dict = field(factory=dict, kw_only=True, metadata={"serializable": True})
@@ -34,28 +35,36 @@
),
kw_only=True,
)
pipe: TextGenerationPipeline = field(
default=Factory(
lambda self: import_optional_dependency("transformers").pipeline(
"text-generation", model=self.model, max_new_tokens=self.max_tokens, tokenizer=self.tokenizer.tokenizer
),
takes_self=True,
)
)

def try_run(self, prompt_stack: PromptStack) -> TextArtifact:
prompt = self.prompt_stack_to_string(prompt_stack)
pipeline = import_optional_dependency("transformers").pipeline
messages = [{"role": input.role, "content": input.content} for input in prompt_stack.inputs]

generator = pipeline(
result = self.pipe(
messages,
max_new_tokens=self.max_tokens,
tokenizer=self.tokenizer.tokenizer,
model=self.model,
max_new_tokens=self.tokenizer.count_output_tokens_left(prompt),
stop_strings=self.tokenizer.stop_sequences,
temperature=self.temperature,
do_sample=True,
)

if generator.task in self.SUPPORTED_TASKS:
extra_params = {"pad_token_id": self.tokenizer.tokenizer.eos_token_id}

response = generator(prompt, **(self.DEFAULT_PARAMS | extra_params | self.params))
if isinstance(result, list):
if len(result) == 1:
generated_text = result[0]["generated_text"][-1]["content"]

if len(response) == 1:
return TextArtifact(value=response[0]["generated_text"].strip())
return TextArtifact(value=generated_text)
else:
raise Exception("completion with more than one choice is not supported yet")
else:
raise Exception(f"only models with the following tasks are supported: {self.SUPPORTED_TASKS}")
raise Exception("invalid output format")

def try_stream(self, prompt_stack: PromptStack) -> Iterator[TextArtifact]:
raise NotImplementedError("streaming is not supported")
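
The rewritten `try_run` expects the chat-shaped payload that the pipeline returns for message input and reads the content of the last message, which is the newly generated assistant turn. An illustrative example of that shape and the extraction (the payload below is made up for demonstration):

```python
# Made-up payload mirroring the pipeline's chat output, and the extraction try_run performs.
result = [
    {
        "generated_text": [
            {"role": "system", "content": "You are a pirate chatbot who always responds in pirate speak!"},
            {"role": "user", "content": "How many helicopters can a human eat in one sitting?"},
            {"role": "assistant", "content": "Arr, none at all, matey!"},
        ]
    }
]

# One returned sequence; the last message is the generated assistant reply.
generated_text = result[0]["generated_text"][-1]["content"]
print(generated_text)  # Arr, none at all, matey!
```

The test changes below update the mock pipeline fixture to return this same shape.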
@@ -13,7 +13,7 @@ def mock_pipeline(self, mocker):
def mock_generator(self, mock_pipeline):
mock_generator = mock_pipeline.return_value
mock_generator.task = "text-generation"
mock_generator.return_value = [{"generated_text": "model-output"}]
mock_generator.return_value = [{"generated_text": [{"content": "model-output"}]}]
return mock_generator

@pytest.fixture(autouse=True)
@@ -44,6 +44,16 @@ def test_try_run(self, prompt_stack):
# Then
assert text_artifact.value == "model-output"

def test_try_stream(self, prompt_stack):
# Given
driver = HuggingFacePipelinePromptDriver(model="foo", max_tokens=42)

# When
with pytest.raises(Exception) as e:
driver.try_stream(prompt_stack)

assert e.value.args[0] == "streaming is not supported"

@pytest.mark.parametrize("choices", [[], [1, 2]])
def test_try_run_throws_when_multiple_choices_returned(self, choices, mock_generator, prompt_stack):
# Given
@@ -55,16 +65,16 @@ def test_try_run_throws_when_multiple_choices_returned(self, choices, mock_gener
driver.try_run(prompt_stack)

# Then
e.value.args[0] == "completion with more than one choice is not supported yet"
assert e.value.args[0] == "completion with more than one choice is not supported yet"

def test_try_run_throws_when_unsupported_task_returned(self, prompt_stack, mock_generator):
def test_try_run_throws_when_non_list(self, mock_generator, prompt_stack):
# Given
driver = HuggingFacePipelinePromptDriver(model="foo", max_tokens=42)
mock_generator.task = "obviously-an-unsupported-task"
mock_generator.return_value = {}

# When
with pytest.raises(Exception) as e:
driver.try_run(prompt_stack)

# Then
assert e.value.args[0].startswith("only models with the following tasks are supported: ")
assert e.value.args[0] == "invalid output format"