diff --git a/docs/griptape-framework/drivers/embedding-drivers.md b/docs/griptape-framework/drivers/embedding-drivers.md index 74a6f550e..b7212e6ff 100644 --- a/docs/griptape-framework/drivers/embedding-drivers.md +++ b/docs/griptape-framework/drivers/embedding-drivers.md @@ -88,10 +88,8 @@ driver = HuggingFaceHubEmbeddingDriver( api_token=os.environ["HUGGINGFACE_HUB_ACCESS_TOKEN"], model="sentence-transformers/all-MiniLM-L6-v2", tokenizer=HuggingFaceTokenizer( + model="sentence-transformers/all-MiniLM-L6-v2", max_output_tokens=512, - tokenizer=AutoTokenizer.from_pretrained( - "sentence-transformers/all-MiniLM-L6-v2" - ) ), ) diff --git a/docs/griptape-framework/misc/tokenizers.md b/docs/griptape-framework/misc/tokenizers.md index 64022e4aa..04b9ed9a6 100644 --- a/docs/griptape-framework/misc/tokenizers.md +++ b/docs/griptape-framework/misc/tokenizers.md @@ -69,8 +69,8 @@ from griptape.tokenizers import HuggingFaceTokenizer tokenizer = HuggingFaceTokenizer( + model="sentence-transformers/all-MiniLM-L6-v2", max_output_tokens=512, - tokenizer=AutoTokenizer.from_pretrained("sentence-transformers/all-MiniLM-L6-v2") ) print(tokenizer.count_tokens("Hello world!")) diff --git a/griptape/drivers/prompt/amazon_sagemaker_jumpstart_prompt_driver.py b/griptape/drivers/prompt/amazon_sagemaker_jumpstart_prompt_driver.py index 54247d001..85b9f24bb 100644 --- a/griptape/drivers/prompt/amazon_sagemaker_jumpstart_prompt_driver.py +++ b/griptape/drivers/prompt/amazon_sagemaker_jumpstart_prompt_driver.py @@ -30,11 +30,7 @@ class AmazonSageMakerJumpstartPromptDriver(BasePromptDriver): max_tokens: int = field(default=250, kw_only=True, metadata={"serializable": True}) tokenizer: HuggingFaceTokenizer = field( default=Factory( - lambda self: HuggingFaceTokenizer( - tokenizer=import_optional_dependency("transformers").AutoTokenizer.from_pretrained(self.model), - max_output_tokens=self.max_tokens, - ), - takes_self=True, + lambda self: HuggingFaceTokenizer(model=self.model, max_output_tokens=self.max_tokens), takes_self=True ), kw_only=True, ) diff --git a/griptape/drivers/prompt/huggingface_hub_prompt_driver.py b/griptape/drivers/prompt/huggingface_hub_prompt_driver.py index c74e00609..aaf0af75a 100644 --- a/griptape/drivers/prompt/huggingface_hub_prompt_driver.py +++ b/griptape/drivers/prompt/huggingface_hub_prompt_driver.py @@ -42,11 +42,7 @@ class HuggingFaceHubPromptDriver(BasePromptDriver): ) tokenizer: HuggingFaceTokenizer = field( default=Factory( - lambda self: HuggingFaceTokenizer( - tokenizer=import_optional_dependency("transformers").AutoTokenizer.from_pretrained(self.model), - max_output_tokens=self.max_tokens, - ), - takes_self=True, + lambda self: HuggingFaceTokenizer(model=self.model, max_output_tokens=self.max_tokens), takes_self=True ), kw_only=True, ) diff --git a/griptape/drivers/prompt/huggingface_pipeline_prompt_driver.py b/griptape/drivers/prompt/huggingface_pipeline_prompt_driver.py index 324ecd74f..6d1c1e4db 100644 --- a/griptape/drivers/prompt/huggingface_pipeline_prompt_driver.py +++ b/griptape/drivers/prompt/huggingface_pipeline_prompt_driver.py @@ -27,11 +27,7 @@ class HuggingFacePipelinePromptDriver(BasePromptDriver): params: dict = field(factory=dict, kw_only=True, metadata={"serializable": True}) tokenizer: HuggingFaceTokenizer = field( default=Factory( - lambda self: HuggingFaceTokenizer( - tokenizer=import_optional_dependency("transformers").AutoTokenizer.from_pretrained(self.model), - max_output_tokens=self.max_tokens, - ), - takes_self=True, + lambda self: HuggingFaceTokenizer(model=self.model, max_output_tokens=self.max_tokens), takes_self=True ), kw_only=True, ) diff --git a/griptape/tokenizers/huggingface_tokenizer.py b/griptape/tokenizers/huggingface_tokenizer.py index 663bd3cc0..1d861fbb3 100644 --- a/griptape/tokenizers/huggingface_tokenizer.py +++ b/griptape/tokenizers/huggingface_tokenizer.py @@ -1,7 +1,7 @@ from __future__ import annotations from typing import TYPE_CHECKING from attrs import define, field, Factory -from griptape.utils import PromptStack +from griptape.utils import PromptStack, import_optional_dependency from griptape.tokenizers import BaseTokenizer if TYPE_CHECKING: @@ -10,12 +10,17 @@ @define() class HuggingFaceTokenizer(BaseTokenizer): - tokenizer: PreTrainedTokenizerBase = field(kw_only=True) - model: None = field(init=False, default=None, kw_only=True) + tokenizer: PreTrainedTokenizerBase = field( + default=Factory( + lambda self: import_optional_dependency("transformers").AutoTokenizer.from_pretrained(self.model), + takes_self=True, + ), + kw_only=True, + ) max_input_tokens: int = field( default=Factory(lambda self: self.tokenizer.model_max_length, takes_self=True), kw_only=True ) - max_output_tokens: int = field(kw_only=True) # pyright: ignore[reportGeneralTypeIssues] + max_output_tokens: int = field(default=4096, kw_only=True) def count_tokens(self, text: str | PromptStack) -> int: if isinstance(text, PromptStack): diff --git a/tests/unit/drivers/prompt/test_openai_chat_prompt_driver.py b/tests/unit/drivers/prompt/test_openai_chat_prompt_driver.py index 05188d7ca..5bca48dba 100644 --- a/tests/unit/drivers/prompt/test_openai_chat_prompt_driver.py +++ b/tests/unit/drivers/prompt/test_openai_chat_prompt_driver.py @@ -1,5 +1,3 @@ -from transformers import AutoTokenizer - from griptape.drivers import OpenAiChatPromptDriver from griptape.tokenizers.huggingface_tokenizer import HuggingFaceTokenizer from griptape.utils import PromptStack @@ -188,7 +186,7 @@ def test_try_run_throws_when_multiple_choices_returned(self, choices, mock_chat_ def test_custom_tokenizer(self, mock_chat_completion_create, prompt_stack, messages): driver = OpenAiChatPromptDriver( model=OpenAiTokenizer.DEFAULT_OPENAI_GPT_3_CHAT_MODEL, - tokenizer=HuggingFaceTokenizer(tokenizer=AutoTokenizer.from_pretrained("gpt2"), max_output_tokens=1000), + tokenizer=HuggingFaceTokenizer(model="gpt2", max_output_tokens=1000), max_tokens=1, ) diff --git a/tests/unit/tokenizers/test_hugging_face_tokenizer.py b/tests/unit/tokenizers/test_hugging_face_tokenizer.py index e9e323735..e4bf169a5 100644 --- a/tests/unit/tokenizers/test_hugging_face_tokenizer.py +++ b/tests/unit/tokenizers/test_hugging_face_tokenizer.py @@ -4,14 +4,13 @@ import pytest # noqa: E402 from griptape.utils import PromptStack # noqa: E402 -from transformers import GPT2Tokenizer # noqa: E402 from griptape.tokenizers import HuggingFaceTokenizer # noqa: E402 class TestHuggingFaceTokenizer: @pytest.fixture def tokenizer(self): - return HuggingFaceTokenizer(tokenizer=GPT2Tokenizer.from_pretrained("gpt2"), max_output_tokens=1024) + return HuggingFaceTokenizer(model="gpt2", max_output_tokens=1024) def test_token_count(self, tokenizer): assert (