
Commit f1f7d77

Merge branch 'main' into develop

gabrielmbmb committed Dec 19, 2024
2 parents 925d259 + bfc8445
Showing 2 changed files with 16 additions and 1 deletion.
2 changes: 1 addition & 1 deletion src/distilabel/models/llms/huggingface/transformers.py
@@ -175,7 +175,7 @@ def prepare_input(self, input: "StandardInput") -> str:
         Returns:
             The prompt to send to the LLM.
         """
-        if self._pipeline.tokenizer.chat_template:  # type: ignore
+        if self._pipeline.tokenizer.chat_template is None:  # type: ignore
             return input[0]["content"]

         prompt: str = (
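
The change inverts the guard: previously the truthy check returned the raw message content exactly when a chat template *was* available, so templating was skipped in the common case; with `is None`, the raw content is only used as a fallback when the tokenizer has no template. Below is a minimal sketch of the corrected method; everything outside the hunk above (the docstring body and the `apply_chat_template` call) is an assumption based on the standard `transformers` tokenizer API, not a verbatim copy of the repository code.

def prepare_input(self, input: "StandardInput") -> str:
    """Prepares the prompt to send to the LLM.

    Returns:
        The prompt to send to the LLM.
    """
    # Fallback introduced by this fix: without a chat template there is
    # nothing to render, so return the raw content of the first message.
    if self._pipeline.tokenizer.chat_template is None:  # type: ignore
        return input[0]["content"]

    # With a template available, render the conversation into a single prompt
    # string (assumed to mirror the lines below the hunk shown above).
    prompt: str = self._pipeline.tokenizer.apply_chat_template(  # type: ignore
        input,
        tokenize=False,
        add_generation_prompt=True,
    )
    return prompt
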
15 changes: 15 additions & 0 deletions tests/unit/models/llms/huggingface/test_transformers.py
@@ -40,6 +40,21 @@ def test_model_name(self, transformers_llm: TransformersLLM) -> None:
== "distilabel-internal-testing/tiny-random-mistral"
)

def test_prepare_input(self, transformers_llm: TransformersLLM) -> None:
assert (
transformers_llm.prepare_input([{"role": "user", "content": "Hello"}])
== "<s> [INST] Hello [/INST]"
)

def test_prepare_input_no_chat_template(
self, transformers_llm: TransformersLLM
) -> None:
transformers_llm._pipeline.tokenizer.chat_template = None
assert (
transformers_llm.prepare_input([{"role": "user", "content": "Hello"}])
== "Hello"
)

def test_generate(self, transformers_llm: TransformersLLM) -> None:
responses = transformers_llm.generate(
inputs=[
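
Both new tests rely on a `transformers_llm` fixture that is defined elsewhere in the test module and does not appear in this diff. A minimal sketch of what such a fixture might look like follows, assuming `TransformersLLM` can be constructed from a model name and that `load()` builds the underlying `transformers` pipeline used by `prepare_input`; the import path and constructor arguments are assumptions based on the file paths in this diff, not code taken from the repository.

import pytest

from distilabel.models.llms.huggingface.transformers import TransformersLLM


@pytest.fixture(scope="module")
def transformers_llm() -> TransformersLLM:
    # Assumed fixture: wrap the tiny test model referenced in test_model_name
    # and load it once for the whole module.
    llm = TransformersLLM(model="distilabel-internal-testing/tiny-random-mistral")
    llm.load()
    return llm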
