
Commit

Fix InferenceEndpointsLLM unit tests
gabrielmbmb committed Dec 12, 2024
1 parent f35fd8b commit 86ee1dd
Showing 2 changed files with 173 additions and 8 deletions.
12 changes: 7 additions & 5 deletions src/distilabel/models/llms/huggingface/inference_endpoints.py
@@ -437,23 +437,25 @@ async def _generate_with_text_generation(
                 if generation and generation.details
                 else 0
             ],
-            logprobs=[self._get_logprobs_from_text_generation(generation)]
+            logprobs=self._get_logprobs_from_text_generation(generation)
             if generation
             else None,  # type: ignore
         )
 
     def _get_logprobs_from_text_generation(
         self, generation: "TextGenerationOutput"
-    ) -> Union[List[List["Logprob"]], None]:
+    ) -> Union[List[List[List["Logprob"]]], None]:
         if generation.details is None or generation.details.top_tokens is None:
             return None
 
         return [
             [
-                {"token": top_logprob["text"], "logprob": top_logprob["logprob"]}
-                for top_logprob in token_logprobs
+                [
+                    {"token": top_logprob["text"], "logprob": top_logprob["logprob"]}
+                    for top_logprob in token_logprobs
+                ]
+                for token_logprobs in generation.details.top_tokens
             ]
-            for token_logprobs in generation.details.top_tokens
         ]
 
     async def _generate_with_chat_completion(
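In plain terms, this hunk makes _get_logprobs_from_text_generation wrap its result in one extra list, so the logprobs it returns are nested as choices → token positions → top-N alternatives, and the call site no longer adds the outer list itself. A minimal sketch of that shape, using illustrative values taken from the unit tests added below (not output from a real endpoint):

# Nesting implied by the new List[List[List["Logprob"]]] return type.
logprobs = [  # one entry per generated choice
    [  # one entry per generated token position
        [  # top-N alternatives for that position, as {"token", "logprob"} dicts
            {"token": "Aenean", "logprob": 0},
            {"token": "Hello", "logprob": -2},
        ],
        [
            {"token": " ", "logprob": 0},
            {"token": ",", "logprob": -2},
        ],
    ]
]
assert logprobs[0][1][0]["token"] == " "  # choice 0, position 1, most likely alternative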
169 changes: 166 additions & 3 deletions tests/unit/models/llms/huggingface/test_inference_endpoints.py
@@ -23,7 +23,10 @@
 from huggingface_hub import (
     ChatCompletionOutput,
     ChatCompletionOutputComplete,
+    ChatCompletionOutputLogprob,
+    ChatCompletionOutputLogprobs,
     ChatCompletionOutputMessage,
+    ChatCompletionOutputTopLogprob,
     ChatCompletionOutputUsage,
 )

@@ -134,6 +137,7 @@ async def test_agenerate_with_text_generation(
                 generated_text="Aenean hendrerit aliquam velit...",
                 details=MagicMock(
                     generated_tokens=66,
+                    top_tokens=None,
                 ),
             )
         )
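Worth noting why top_tokens=None is now set explicitly on the mocked details here (and in the structured-output test further down): an attribute that is never assigned on a MagicMock is auto-created as a child mock, so it is not None and would push _generate_with_text_generation into the logprobs branch. A small standalone sketch of the difference:

from unittest.mock import MagicMock

details = MagicMock(generated_tokens=66)
print(details.top_tokens is None)  # False: MagicMock auto-creates the attribute

details = MagicMock(generated_tokens=66, top_tokens=None)
print(details.top_tokens is None)  # True: _get_logprobs_from_text_generation returns None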
@@ -146,12 +150,72 @@ async def test_agenerate_with_text_generation(
                 },
             ]
         )
 
         assert result == {
             "generations": ["Aenean hendrerit aliquam velit..."],
             "statistics": {
                 "input_tokens": [31],
                 "output_tokens": [66],
             },
         }
 
+    @pytest.mark.asyncio
+    async def test_agenerate_with_text_generation_and_top_n_tokens(
+        self, mock_inference_client: MagicMock
+    ) -> None:
+        llm = InferenceEndpointsLLM(
+            model_id="distilabel-internal-testing/tiny-random-mistral",
+            tokenizer_id="distilabel-internal-testing/tiny-random-mistral",
+        )
+        llm.load()
+
+        llm._aclient.text_generation = AsyncMock(
+            return_value=MagicMock(
+                generated_text="Aenean hendrerit aliquam velit...",
+                details=MagicMock(
+                    generated_tokens=66,
+                    top_tokens=[
+                        [
+                            {"logprob": 0, "text": "Aenean"},
+                            {"logprob": -2, "text": "Hello"},
+                        ],
+                        [
+                            {"logprob": 0, "text": " "},
+                            {"logprob": -2, "text": ","},
+                        ],
+                    ],
+                ),
+            )
+        )
+
+        result = await llm.agenerate(
+            input=[
+                {
+                    "role": "user",
+                    "content": "Lorem ipsum dolor sit amet, consectetur adipiscing elit.",
+                },
+            ],
+            top_n_tokens=2,
+        )
+
+        assert result == {
+            "generations": ["Aenean hendrerit aliquam velit..."],
+            "statistics": {
+                "input_tokens": [31],
+                "output_tokens": [66],
+            },
+            "logprobs": [
+                [
+                    [
+                        {"logprob": 0, "token": "Aenean"},
+                        {"logprob": -2, "token": "Hello"},
+                    ],
+                    [
+                        {"logprob": 0, "token": " "},
+                        {"logprob": -2, "token": ","},
+                    ],
+                ]
+            ],
+        }
 
     @pytest.mark.asyncio
@@ -201,6 +265,107 @@ async def test_agenerate_with_chat_completion(
             },
         }
 
+    @pytest.mark.asyncio
+    async def test_agenerate_with_chat_completion_and_logprobs_and_top_logprobs(
+        self, mock_inference_client: MagicMock
+    ) -> None:
+        llm = InferenceEndpointsLLM(
+            model_id="distilabel-internal-testing/tiny-random-mistral",
+        )
+        llm.load()
+
+        llm._aclient.chat_completion = AsyncMock(  # type: ignore
+            return_value=ChatCompletionOutput(  # type: ignore
+                choices=[
+                    ChatCompletionOutputComplete(
+                        finish_reason="length",
+                        index=0,
+                        message=ChatCompletionOutputMessage(
+                            role="assistant",
+                            content=" Aenean hendrerit aliquam velit. ...",
+                        ),
+                        logprobs=ChatCompletionOutputLogprobs(
+                            content=[
+                                ChatCompletionOutputLogprob(
+                                    logprob=0,
+                                    token=" ",
+                                    top_logprobs=[
+                                        ChatCompletionOutputTopLogprob(
+                                            logprob=0, token=" "
+                                        ),
+                                        ChatCompletionOutputTopLogprob(
+                                            logprob=-1, token="Hello"
+                                        ),
+                                    ],
+                                ),
+                                ChatCompletionOutputLogprob(
+                                    logprob=0,
+                                    token="Aenean",
+                                    top_logprobs=[
+                                        ChatCompletionOutputTopLogprob(
+                                            logprob=0, token="Aenean"
+                                        ),
+                                        ChatCompletionOutputTopLogprob(
+                                            logprob=-1, token="miau"
+                                        ),
+                                    ],
+                                ),
+                            ]
+                        ),
+                    )
+                ],
+                created=1721045246,
+                id="",
+                model="meta-llama/Meta-Llama-3-70B-Instruct",
+                system_fingerprint="2.1.1-dev0-sha-4327210",
+                usage=ChatCompletionOutputUsage(
+                    completion_tokens=66, prompt_tokens=18, total_tokens=84
+                ),
+            )
+        )
+
+        result = await llm.agenerate(
+            input=[
+                {
+                    "role": "user",
+                    "content": "Lorem ipsum dolor sit amet, consectetur adipiscing elit.",
+                },
+            ],
+            logprobs=True,
+            top_logprobs=2,
+        )
+        assert result == {
+            "generations": [" Aenean hendrerit aliquam velit. ..."],
+            "statistics": {
+                "input_tokens": [18],
+                "output_tokens": [66],
+            },
+            "logprobs": [
+                [
+                    [
+                        {
+                            "logprob": 0,
+                            "token": " ",
+                        },
+                        {
+                            "logprob": -1,
+                            "token": "Hello",
+                        },
+                    ],
+                    [
+                        {
+                            "logprob": 0,
+                            "token": "Aenean",
+                        },
+                        {
+                            "logprob": -1,
+                            "token": "miau",
+                        },
+                    ],
+                ]
+            ],
+        }
 
     @pytest.mark.asyncio
     async def test_agenerate_with_chat_completion_fails(
         self, mock_inference_client: MagicMock
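For the chat-completion path the top-N information arrives as ChatCompletionOutputLogprob / ChatCompletionOutputTopLogprob objects under choices[i].logprobs.content rather than as details.top_tokens dicts, yet the asserted result uses the same nested list-of-dicts shape as the text-generation test above. A rough illustration of that mapping, written against the mocked ChatCompletionOutput from this test (a sketch only, not distilabel's implementation):

def to_logprobs(chat_completion_output):
    # choices -> token positions -> top-N alternatives, mirroring the assertion above
    return [
        [
            [
                {"token": top.token, "logprob": top.logprob}
                for top in token_logprob.top_logprobs
            ]
            for token_logprob in choice.logprobs.content
        ]
        for choice in chat_completion_output.choices
    ]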
@@ -338,9 +503,7 @@ async def test_agenerate_with_structured_output(
         llm._aclient.text_generation = AsyncMock(
             return_value=MagicMock(
                 generated_text="Aenean hendrerit aliquam velit...",
-                details=MagicMock(
-                    generated_tokens=66,
-                ),
+                details=MagicMock(generated_tokens=66, top_tokens=None),
             )
         )
         # Since there's a pseudo-random number within the generation kwargs, we set the seed
