diff --git a/CHANGELOG.md b/CHANGELOG.md
index 87de56beb0..ef2a4aecb7 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -6,6 +6,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 
 ## Unreleased
 
+### Changed
+- Updated `HuggingFacePipelinePromptDriver` to use chat features of `transformers.TextGenerationPipeline`.
+
 ## [0.26.0] - 2024-06-04
 
 ### Added
diff --git a/docs/griptape-framework/drivers/prompt-drivers.md b/docs/griptape-framework/drivers/prompt-drivers.md
index 1006f215fa..bcd24fc74f 100644
--- a/docs/griptape-framework/drivers/prompt-drivers.md
+++ b/docs/griptape-framework/drivers/prompt-drivers.md
@@ -330,62 +330,37 @@ agent.run("Write the code for a snake game.")
 !!! info
     This driver requires the `drivers-prompt-huggingface-pipeline` [extra](../index.md#extras).
 
-The [HuggingFacePipelinePromptDriver](../../reference/griptape/drivers/prompt/huggingface_pipeline_prompt_driver.md) uses [Hugging Face Pipelines](https://huggingface.co/docs/transformers/main_classes/pipelines) for inference locally. It supports models with the following tasks:
-
-- text2text-generation
-- text-generation
+The [HuggingFacePipelinePromptDriver](../../reference/griptape/drivers/prompt/huggingface_pipeline_prompt_driver.md) uses [Hugging Face Pipelines](https://huggingface.co/docs/transformers/main_classes/pipelines) for inference locally.
 
 !!! warning
     Running a model locally can be a computationally expensive process.
 
 ```python
-import os
 from griptape.structures import Agent
 from griptape.drivers import HuggingFacePipelinePromptDriver
 from griptape.rules import Rule, Ruleset
-from griptape.utils import PromptStack
 from griptape.config import StructureConfig
 
-# Override the default Prompt Stack to string converter
-# to format the prompt in a way that is easier for this model to understand.
-def prompt_stack_to_string_converter(prompt_stack: PromptStack) -> str:
-    prompt_lines = []
-
-    for i in prompt_stack.inputs:
-        if i.is_user():
-            prompt_lines.append(f"User: {i.content}")
-        elif i.is_assistant():
-            prompt_lines.append(f"Girafatron: {i.content}")
-        else:
-            prompt_lines.append(f"Instructions: {i.content}")
-    prompt_lines.append("Girafatron:")
-
-    return "\n".join(prompt_lines)
-
-
 agent = Agent(
     config=StructureConfig(
         prompt_driver=HuggingFacePipelinePromptDriver(
-            model="TinyLlama/TinyLlama-1.1B-Chat-v0.6",
-            prompt_stack_to_string=prompt_stack_to_string_converter,
+            model="TinyLlama/TinyLlama-1.1B-Chat-v1.0",
         )
     ),
     rulesets=[
         Ruleset(
-            name="Girafatron",
+            name="Pirate",
             rules=[
                 Rule(
-                    value="You are Girafatron, a giraffe-obsessed robot. You are talking to a human. "
-                    "Girafatron is obsessed with giraffes, the most glorious animal on the face of this Earth. "
-                    "Giraftron believes all other animals are irrelevant when compared to the glorious majesty of the giraffe."
+                    value="You are a pirate chatbot who always responds in pirate speak!"
                 )
-            ]
+            ],
         )
     ],
 )
 
-agent.run("Hello Girafatron, what is your favorite animal?")
+agent.run("How many helicopters can a human eat in one sitting?")
 ```
 
 ### Multi Model Prompt Drivers
diff --git a/griptape/drivers/prompt/huggingface_pipeline_prompt_driver.py b/griptape/drivers/prompt/huggingface_pipeline_prompt_driver.py
index bde6d5e4e2..82e58d7bd9 100644
--- a/griptape/drivers/prompt/huggingface_pipeline_prompt_driver.py
+++ b/griptape/drivers/prompt/huggingface_pipeline_prompt_driver.py
@@ -1,5 +1,7 @@
+from __future__ import annotations
 from collections.abc import Iterator
+from typing import TYPE_CHECKING
 
 from attrs import Factory, define, field
 
 from griptape.artifacts import TextArtifact
@@ -7,6 +9,9 @@
 from griptape.tokenizers import HuggingFaceTokenizer
 from griptape.utils import PromptStack, import_optional_dependency
 
+if TYPE_CHECKING:
+    from transformers import TextGenerationPipeline
+
 
 @define
 class HuggingFacePipelinePromptDriver(BasePromptDriver):
@@ -14,13 +19,9 @@ class HuggingFacePipelinePromptDriver(BasePromptDriver):
     Attributes:
         params: Custom model run parameters.
         model: Hugging Face Hub model name.
-        tokenizer: Custom `HuggingFaceTokenizer`.
 
     """
 
-    SUPPORTED_TASKS = ["text2text-generation", "text-generation"]
-    DEFAULT_PARAMS = {"return_full_text": False, "num_return_sequences": 1}
-
     max_tokens: int = field(default=250, kw_only=True, metadata={"serializable": True})
     model: str = field(kw_only=True, metadata={"serializable": True})
     params: dict = field(factory=dict, kw_only=True, metadata={"serializable": True})
@@ -34,28 +35,38 @@ class HuggingFacePipelinePromptDriver(BasePromptDriver):
         ),
         kw_only=True,
     )
+    pipe: TextGenerationPipeline = field(
+        default=Factory(
+            lambda self: import_optional_dependency("transformers").pipeline(
+                "text-generation", model=self.model, max_new_tokens=self.max_tokens, tokenizer=self.tokenizer.tokenizer
+            ),
+            takes_self=True,
+        )
+    )
 
     def try_run(self, prompt_stack: PromptStack) -> TextArtifact:
-        prompt = self.prompt_stack_to_string(prompt_stack)
-        pipeline = import_optional_dependency("transformers").pipeline
+        messages = [{"role": input.role, "content": input.content} for input in prompt_stack.inputs]
 
-        generator = pipeline(
-            tokenizer=self.tokenizer.tokenizer,
-            model=self.model,
-            max_new_tokens=self.tokenizer.count_output_tokens_left(prompt),
+        result = self.pipe(
+            messages,
+            max_new_tokens=self.max_tokens,
+            eos_token_id=[
+                self.tokenizer.tokenizer.eos_token_id,
+                *[self.pipe.tokenizer.convert_tokens_to_ids(token) for token in self.tokenizer.stop_sequences],
+            ],
+            temperature=self.temperature,
+            do_sample=True,
         )
 
-        if generator.task in self.SUPPORTED_TASKS:
-            extra_params = {"pad_token_id": self.tokenizer.tokenizer.eos_token_id}
-
-            response = generator(prompt, **(self.DEFAULT_PARAMS | extra_params | self.params))
+        if isinstance(result, list):
+            if len(result) == 1:
+                generated_text = result[0]["generated_text"][-1]["content"]
 
-            if len(response) == 1:
-                return TextArtifact(value=response[0]["generated_text"].strip())
+                return TextArtifact(value=generated_text)
             else:
                 raise Exception("completion with more than one choice is not supported yet")
         else:
-            raise Exception(f"only models with the following tasks are supported: {self.SUPPORTED_TASKS}")
+            raise Exception("invalid output format")
 
     def try_stream(self, prompt_stack: PromptStack) -> Iterator[TextArtifact]:
         raise NotImplementedError("streaming is not supported")
diff --git a/tests/unit/drivers/prompt/test_hugging_face_pipeline_prompt_driver.py b/tests/unit/drivers/prompt/test_hugging_face_pipeline_prompt_driver.py
index fec39da4db..b565ccf4d9 100644
--- a/tests/unit/drivers/prompt/test_hugging_face_pipeline_prompt_driver.py
+++ b/tests/unit/drivers/prompt/test_hugging_face_pipeline_prompt_driver.py
@@ -13,7 +13,7 @@ def mock_pipeline(self, mocker):
     def mock_generator(self, mock_pipeline):
         mock_generator = mock_pipeline.return_value
         mock_generator.task = "text-generation"
-        mock_generator.return_value = [{"generated_text": "model-output"}]
+        mock_generator.return_value = [{"generated_text": [{"content": "model-output"}]}]
         return mock_generator
 
     @pytest.fixture(autouse=True)
@@ -44,6 +44,16 @@ def test_try_run(self, prompt_stack):
         # Then
         assert text_artifact.value == "model-output"
 
+    def test_try_stream(self, prompt_stack):
+        # Given
+        driver = HuggingFacePipelinePromptDriver(model="foo", max_tokens=42)
+
+        # When
+        with pytest.raises(Exception) as e:
+            driver.try_stream(prompt_stack)
+
+        assert e.value.args[0] == "streaming is not supported"
+
     @pytest.mark.parametrize("choices", [[], [1, 2]])
     def test_try_run_throws_when_multiple_choices_returned(self, choices, mock_generator, prompt_stack):
         # Given
@@ -55,16 +65,16 @@ def test_try_run_throws_when_multiple_choices_returned(self, choices, mock_gener
             driver.try_run(prompt_stack)
 
         # Then
-        e.value.args[0] == "completion with more than one choice is not supported yet"
+        assert e.value.args[0] == "completion with more than one choice is not supported yet"
 
-    def test_try_run_throws_when_unsupported_task_returned(self, prompt_stack, mock_generator):
+    def test_try_run_throws_when_non_list(self, mock_generator, prompt_stack):
         # Given
         driver = HuggingFacePipelinePromptDriver(model="foo", max_tokens=42)
-        mock_generator.task = "obviously-an-unsupported-task"
+        mock_generator.return_value = {}
 
         # When
         with pytest.raises(Exception) as e:
             driver.try_run(prompt_stack)
 
         # Then
-        assert e.value.args[0].startswith("only models with the following tasks are supported: ")
+        assert e.value.args[0] == "invalid output format"
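
For context on the output shape the new `try_run` unpacks: a minimal sketch of the chat-style `TextGenerationPipeline` call this PR builds on, assuming a recent `transformers` release with chat-template support. The model name and sampling parameters are illustrative, not the driver's defaults.

```python
# Minimal sketch (assumption: a transformers version whose text-generation
# pipeline accepts chat messages; model and parameters are examples only).
from transformers import pipeline

pipe = pipeline("text-generation", model="TinyLlama/TinyLlama-1.1B-Chat-v1.0")

messages = [
    {"role": "system", "content": "You are a pirate chatbot who always responds in pirate speak!"},
    {"role": "user", "content": "How many helicopters can a human eat in one sitting?"},
]

# Passing a list of role/content dicts makes the pipeline apply the tokenizer's
# chat template before generating.
result = pipe(messages, max_new_tokens=250, do_sample=True, temperature=0.7)

# The result echoes the conversation with the new assistant turn last: the same
# shape the driver reads via result[0]["generated_text"][-1]["content"] and the
# shape the updated mock_generator fixture mimics.
print(result[0]["generated_text"][-1]["content"])
```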