From 86ee1ddcfc8063333662587eebe58f1a2042c00d Mon Sep 17 00:00:00 2001
From: Gabriel Martín Blázquez
Date: Thu, 12 Dec 2024 10:42:16 +0100
Subject: [PATCH] Fix `InferenceEndpointsLLM` unit tests

---
 .../llms/huggingface/inference_endpoints.py   |  12 +-
 .../huggingface/test_inference_endpoints.py   | 169 +++++++++++++++++-
 2 files changed, 173 insertions(+), 8 deletions(-)

diff --git a/src/distilabel/models/llms/huggingface/inference_endpoints.py b/src/distilabel/models/llms/huggingface/inference_endpoints.py
index e69dfaf49c..613c533b01 100644
--- a/src/distilabel/models/llms/huggingface/inference_endpoints.py
+++ b/src/distilabel/models/llms/huggingface/inference_endpoints.py
@@ -437,23 +437,25 @@ async def _generate_with_text_generation(
                 if generation and generation.details
                 else 0
             ],
-            logprobs=[self._get_logprobs_from_text_generation(generation)]
+            logprobs=self._get_logprobs_from_text_generation(generation)
             if generation
             else None,  # type: ignore
         )
 
     def _get_logprobs_from_text_generation(
         self, generation: "TextGenerationOutput"
-    ) -> Union[List[List["Logprob"]], None]:
+    ) -> Union[List[List[List["Logprob"]]], None]:
         if generation.details is None or generation.details.top_tokens is None:
             return None
 
         return [
             [
-                {"token": top_logprob["text"], "logprob": top_logprob["logprob"]}
-                for top_logprob in token_logprobs
+                [
+                    {"token": top_logprob["text"], "logprob": top_logprob["logprob"]}
+                    for top_logprob in token_logprobs
+                ]
+                for token_logprobs in generation.details.top_tokens
             ]
-            for token_logprobs in generation.details.top_tokens
         ]
 
     async def _generate_with_chat_completion(
diff --git a/tests/unit/models/llms/huggingface/test_inference_endpoints.py b/tests/unit/models/llms/huggingface/test_inference_endpoints.py
index 874cd9a595..f1dcd5e028 100644
--- a/tests/unit/models/llms/huggingface/test_inference_endpoints.py
+++ b/tests/unit/models/llms/huggingface/test_inference_endpoints.py
@@ -23,7 +23,10 @@
 from huggingface_hub import (
     ChatCompletionOutput,
     ChatCompletionOutputComplete,
+    ChatCompletionOutputLogprob,
+    ChatCompletionOutputLogprobs,
     ChatCompletionOutputMessage,
+    ChatCompletionOutputTopLogprob,
     ChatCompletionOutputUsage,
 )
 
@@ -134,6 +137,7 @@ async def test_agenerate_with_text_generation(
                 generated_text="Aenean hendrerit aliquam velit...",
                 details=MagicMock(
                     generated_tokens=66,
+                    top_tokens=None,
                 ),
             )
         )
@@ -146,12 +150,72 @@ async def test_agenerate_with_text_generation(
                 },
             ]
         )
+
+        assert result == {
+            "generations": ["Aenean hendrerit aliquam velit..."],
+            "statistics": {
+                "input_tokens": [31],
+                "output_tokens": [66],
+            },
+        }
+
+    @pytest.mark.asyncio
+    async def test_agenerate_with_text_generation_and_top_n_tokens(
+        self, mock_inference_client: MagicMock
+    ) -> None:
+        llm = InferenceEndpointsLLM(
+            model_id="distilabel-internal-testing/tiny-random-mistral",
+            tokenizer_id="distilabel-internal-testing/tiny-random-mistral",
+        )
+        llm.load()
+
+        llm._aclient.text_generation = AsyncMock(
+            return_value=MagicMock(
+                generated_text="Aenean hendrerit aliquam velit...",
+                details=MagicMock(
+                    generated_tokens=66,
+                    top_tokens=[
+                        [
+                            {"logprob": 0, "text": "Aenean"},
+                            {"logprob": -2, "text": "Hello"},
+                        ],
+                        [
+                            {"logprob": 0, "text": " "},
+                            {"logprob": -2, "text": ","},
+                        ],
+                    ],
+                ),
+            )
+        )
+
+        result = await llm.agenerate(
+            input=[
+                {
+                    "role": "user",
+                    "content": "Lorem ipsum dolor sit amet, consectetur adipiscing elit.",
+                },
+            ],
+            top_n_tokens=2,
+        )
+
         assert result == {
             "generations": ["Aenean hendrerit aliquam velit..."],
             "statistics": {
                 "input_tokens": [31],
                 "output_tokens": [66],
             },
+            "logprobs": [
+                [
+                    [
+                        {"logprob": 0, "token": "Aenean"},
+                        {"logprob": -2, "token": "Hello"},
+                    ],
+                    [
+                        {"logprob": 0, "token": " "},
+                        {"logprob": -2, "token": ","},
+                    ],
+                ]
+            ],
         }
 
     @pytest.mark.asyncio
@@ -201,6 +265,107 @@ async def test_agenerate_with_chat_completion(
             },
         }
 
+    @pytest.mark.asyncio
+    async def test_agenerate_with_chat_completion_and_logprobs_and_top_logprobs(
+        self, mock_inference_client: MagicMock
+    ) -> None:
+        llm = InferenceEndpointsLLM(
+            model_id="distilabel-internal-testing/tiny-random-mistral",
+        )
+        llm.load()
+
+        llm._aclient.chat_completion = AsyncMock(  # type: ignore
+            return_value=ChatCompletionOutput(  # type: ignore
+                choices=[
+                    ChatCompletionOutputComplete(
+                        finish_reason="length",
+                        index=0,
+                        message=ChatCompletionOutputMessage(
+                            role="assistant",
+                            content=" Aenean hendrerit aliquam velit. ...",
+                        ),
+                        logprobs=ChatCompletionOutputLogprobs(
+                            content=[
+                                ChatCompletionOutputLogprob(
+                                    logprob=0,
+                                    token=" ",
+                                    top_logprobs=[
+                                        ChatCompletionOutputTopLogprob(
+                                            logprob=0, token=" "
+                                        ),
+                                        ChatCompletionOutputTopLogprob(
+                                            logprob=-1, token="Hello"
+                                        ),
+                                    ],
+                                ),
+                                ChatCompletionOutputLogprob(
+                                    logprob=0,
+                                    token="Aenean",
+                                    top_logprobs=[
+                                        ChatCompletionOutputTopLogprob(
+                                            logprob=0, token="Aenean"
+                                        ),
+                                        ChatCompletionOutputTopLogprob(
+                                            logprob=-1, token="miau"
+                                        ),
+                                    ],
+                                ),
+                            ]
+                        ),
+                    )
+                ],
+                created=1721045246,
+                id="",
+                model="meta-llama/Meta-Llama-3-70B-Instruct",
+                system_fingerprint="2.1.1-dev0-sha-4327210",
+                usage=ChatCompletionOutputUsage(
+                    completion_tokens=66, prompt_tokens=18, total_tokens=84
+                ),
+            )
+        )
+
+        result = await llm.agenerate(
+            input=[
+                {
+                    "role": "user",
+                    "content": "Lorem ipsum dolor sit amet, consectetur adipiscing elit.",
+                },
+            ],
+            logprobs=True,
+            top_logprobs=2,
+        )
+        assert result == {
+            "generations": [" Aenean hendrerit aliquam velit. ..."],
+            "statistics": {
+                "input_tokens": [18],
+                "output_tokens": [66],
+            },
+            "logprobs": [
+                [
+                    [
+                        {
+                            "logprob": 0,
+                            "token": " ",
+                        },
+                        {
+                            "logprob": -1,
+                            "token": "Hello",
+                        },
+                    ],
+                    [
+                        {
+                            "logprob": 0,
+                            "token": "Aenean",
+                        },
+                        {
+                            "logprob": -1,
+                            "token": "miau",
+                        },
+                    ],
+                ]
+            ],
+        }
+
     @pytest.mark.asyncio
     async def test_agenerate_with_chat_completion_fails(
         self, mock_inference_client: MagicMock
@@ -338,9 +503,7 @@ async def test_agenerate_with_structured_output(
         llm._aclient.text_generation = AsyncMock(
             return_value=MagicMock(
                 generated_text="Aenean hendrerit aliquam velit...",
-                details=MagicMock(
-                    generated_tokens=66,
-                ),
+                details=MagicMock(generated_tokens=66, top_tokens=None),
            )
         )
         # Since there's a pseudo-random number within the generation kwargs, we set the seed