Skip to content

Commit

Permalink
Simplify HuggingFaceTokenizer
Browse files: browse the repository at this point in the history
  • Loading branch information
collindutter committed Jun 7, 2024
1 parent 50b7432 commit 15045bd
Show file tree
Hide file tree
Showing 8 changed files with 16 additions and 28 deletions.
4 changes: 1 addition & 3 deletions docs/griptape-framework/drivers/embedding-drivers.md
Original file line number Diff line number Diff line change
Expand Up @@ -88,10 +88,8 @@ driver = HuggingFaceHubEmbeddingDriver(
api_token=os.environ["HUGGINGFACE_HUB_ACCESS_TOKEN"],
model="sentence-transformers/all-MiniLM-L6-v2",
tokenizer=HuggingFaceTokenizer(
model="sentence-transformers/all-MiniLM-L6-v2",
max_output_tokens=512,
tokenizer=AutoTokenizer.from_pretrained(
"sentence-transformers/all-MiniLM-L6-v2"
)
),
)

Expand Down
2 changes: 1 addition & 1 deletion docs/griptape-framework/misc/tokenizers.md
Original file line number Diff line number Diff line change
Expand Up @@ -69,8 +69,8 @@ from griptape.tokenizers import HuggingFaceTokenizer


tokenizer = HuggingFaceTokenizer(
model="sentence-transformers/all-MiniLM-L6-v2",
max_output_tokens=512,
tokenizer=AutoTokenizer.from_pretrained("sentence-transformers/all-MiniLM-L6-v2")
)

print(tokenizer.count_tokens("Hello world!"))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,11 +30,7 @@ class AmazonSageMakerJumpstartPromptDriver(BasePromptDriver):
max_tokens: int = field(default=250, kw_only=True, metadata={"serializable": True})
tokenizer: HuggingFaceTokenizer = field(
default=Factory(
lambda self: HuggingFaceTokenizer(
tokenizer=import_optional_dependency("transformers").AutoTokenizer.from_pretrained(self.model),
max_output_tokens=self.max_tokens,
),
takes_self=True,
lambda self: HuggingFaceTokenizer(model=self.model, max_output_tokens=self.max_tokens), takes_self=True
),
kw_only=True,
)
Expand Down
6 changes: 1 addition & 5 deletions griptape/drivers/prompt/huggingface_hub_prompt_driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,11 +42,7 @@ class HuggingFaceHubPromptDriver(BasePromptDriver):
)
tokenizer: HuggingFaceTokenizer = field(
default=Factory(
lambda self: HuggingFaceTokenizer(
tokenizer=import_optional_dependency("transformers").AutoTokenizer.from_pretrained(self.model),
max_output_tokens=self.max_tokens,
),
takes_self=True,
lambda self: HuggingFaceTokenizer(model=self.model, max_output_tokens=self.max_tokens), takes_self=True
),
kw_only=True,
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,11 +27,7 @@ class HuggingFacePipelinePromptDriver(BasePromptDriver):
params: dict = field(factory=dict, kw_only=True, metadata={"serializable": True})
tokenizer: HuggingFaceTokenizer = field(
default=Factory(
lambda self: HuggingFaceTokenizer(
tokenizer=import_optional_dependency("transformers").AutoTokenizer.from_pretrained(self.model),
max_output_tokens=self.max_tokens,
),
takes_self=True,
lambda self: HuggingFaceTokenizer(model=self.model, max_output_tokens=self.max_tokens), takes_self=True
),
kw_only=True,
)
Expand Down
13 changes: 9 additions & 4 deletions griptape/tokenizers/huggingface_tokenizer.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from __future__ import annotations
from typing import TYPE_CHECKING
from attrs import define, field, Factory
from griptape.utils import PromptStack
from griptape.utils import PromptStack, import_optional_dependency
from griptape.tokenizers import BaseTokenizer

if TYPE_CHECKING:
Expand All @@ -10,12 +10,17 @@

@define()
class HuggingFaceTokenizer(BaseTokenizer):
tokenizer: PreTrainedTokenizerBase = field(kw_only=True)
model: None = field(init=False, default=None, kw_only=True)
tokenizer: PreTrainedTokenizerBase = field(
default=Factory(
lambda self: import_optional_dependency("transformers").AutoTokenizer.from_pretrained(self.model),
takes_self=True,
),
kw_only=True,
)
max_input_tokens: int = field(
default=Factory(lambda self: self.tokenizer.model_max_length, takes_self=True), kw_only=True
)
max_output_tokens: int = field(kw_only=True) # pyright: ignore[reportGeneralTypeIssues]
max_output_tokens: int = field(default=4096, kw_only=True)

def count_tokens(self, text: str | PromptStack) -> int:
if isinstance(text, PromptStack):
Expand Down
4 changes: 1 addition & 3 deletions tests/unit/drivers/prompt/test_openai_chat_prompt_driver.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
from transformers import AutoTokenizer

from griptape.drivers import OpenAiChatPromptDriver
from griptape.tokenizers.huggingface_tokenizer import HuggingFaceTokenizer
from griptape.utils import PromptStack
Expand Down Expand Up @@ -188,7 +186,7 @@ def test_try_run_throws_when_multiple_choices_returned(self, choices, mock_chat_
def test_custom_tokenizer(self, mock_chat_completion_create, prompt_stack, messages):
driver = OpenAiChatPromptDriver(
model=OpenAiTokenizer.DEFAULT_OPENAI_GPT_3_CHAT_MODEL,
tokenizer=HuggingFaceTokenizer(tokenizer=AutoTokenizer.from_pretrained("gpt2"), max_output_tokens=1000),
tokenizer=HuggingFaceTokenizer(model="gpt2", max_output_tokens=1000),
max_tokens=1,
)

Expand Down
3 changes: 1 addition & 2 deletions tests/unit/tokenizers/test_hugging_face_tokenizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,13 @@

import pytest # noqa: E402
from griptape.utils import PromptStack # noqa: E402
from transformers import GPT2Tokenizer # noqa: E402
from griptape.tokenizers import HuggingFaceTokenizer # noqa: E402


class TestHuggingFaceTokenizer:
@pytest.fixture
def tokenizer(self):
return HuggingFaceTokenizer(tokenizer=GPT2Tokenizer.from_pretrained("gpt2"), max_output_tokens=1024)
return HuggingFaceTokenizer(model="gpt2", max_output_tokens=1024)

def test_token_count(self, tokenizer):
assert (
Expand Down

0 comments on commit 15045bd

Please sign in to comment.