diff --git a/adalflow/adalflow/core/generator.py b/adalflow/adalflow/core/generator.py index baedd8fb..e8d8864d 100644 --- a/adalflow/adalflow/core/generator.py +++ b/adalflow/adalflow/core/generator.py @@ -7,7 +7,16 @@ import os from pathlib import Path -from typing import Any, Dict, Optional, Union, Callable, Tuple, List +from typing import ( + Any, + Dict, + Optional, + Union, + Callable, + Tuple, + List, + Generator as GeneratorType, +) import logging @@ -304,24 +313,54 @@ def _extra_repr(self) -> str: s = f"model_kwargs={self.model_kwargs}, model_type={self.model_type}" return s + def _process_chunk(self, chunk: Any) -> GeneratorOutput: + """Process a single chunk of data using the output processors. + + Args: + chunk: Raw chunk data to process + + Returns: + Any: Processed chunk + str: Error string in case of an exception + """ + if not chunk or not self.output_processors: + return chunk, None + + try: + processed_data = self.output_processors(chunk) + return processed_data, None + except Exception as e: + log.error(f"Error processing chunk using the output processors: {e}") + return None, str(e) + def _post_call(self, completion: Any) -> GeneratorOutput: - r"""Get string completion and process it with the output_processors.""" - # parse chat completion will only fill the raw_response - output: GeneratorOutput = self.model_client.parse_chat_completion(completion) - # Now adding the data filed to the output - data = output.raw_response - if self.output_processors: - if data: + """Process completion output, handling both streaming and non-streaming cases. + + Args: + completion: Raw completion data from the llm provider + + Returns: + GeneratorOutput containing processed data or generator type + """ + # Parse chat completion will only fill the raw_response + output = self.model_client.parse_chat_completion(completion) + # Handle streaming case + if isinstance(output, GeneratorType): + + def process_stream(): try: - data = self.output_processors(data) - output.data = data + for out in output: + log.debug(f"Processing raw chunk: {out.raw_response}") + out.data, out.error = self._process_chunk(out.raw_response) + yield out except Exception as e: - log.error(f"Error processing the output processors: {e}") - output.error = str(e) + log.error(f"Error in stream processing: {e}") + yield GeneratorOutput(error=str(e)) + return GeneratorOutput(data=process_stream(), raw_response=output) else: - output.data = data - + # Handle non-streaming case + output.data, output.error = self._process_chunk(output.raw_response) return output def _pre_call(self, prompt_kwargs: Dict, model_kwargs: Dict) -> Dict[str, Any]: diff --git a/adalflow/tests/test_generator.py b/adalflow/tests/test_generator.py index a15c302a..38d331ce 100644 --- a/adalflow/tests/test_generator.py +++ b/adalflow/tests/test_generator.py @@ -15,6 +15,7 @@ from adalflow.core.model_client import ModelClient from adalflow.components.model_client.groq_client import GroqAPIClient from adalflow.tracing import GeneratorStateLogger +from typing import Generator as GeneratorType class TestGenerator(IsolatedAsyncioTestCase): @@ -192,5 +193,72 @@ def test_groq_client_call(self, mock_call): # self.assertEqual(output.data, "Generated text response") +class TestGeneratorWithStream(unittest.TestCase): + + def setUp(self): + """Set up the mocked environment for the stream test cases.""" + self.sent_chunks = [ + "Pa", + "ris", + " is", + " the", + " capital", + " of", + " France", + ] + + with patch( + "adalflow.core.model_client.ModelClient", spec=ModelClient + ) as MockAPI: + mock_api_client = Mock(spec=ModelClient) + MockAPI.return_value = mock_api_client + + mock_api_client.convert_inputs_to_api_kwargs.return_value = { + "model": "phi3:latest", + "stream": True, + "prompt": ( + "\nYou are a helpful assistant.\n\n" + "\nWhat is the capital of France?\n" + ), + } + + mock_api_client.parse_chat_completion.return_value = ( + self._mock_stream_generator(self.sent_chunks) + ) + self.mock_api_client = mock_api_client + + self.generator = Generator(model_client=self.mock_api_client) + + def test_generator_call_with_stream(self): + """Test the generator call with streaming enabled.""" + prompt_kwargs = {"input_str": "What is the capital of France?"} + model_kwargs = {"model": "phi3:latest", "stream": True} + + output = self.generator.call( + prompt_kwargs=prompt_kwargs, model_kwargs=model_kwargs + ) + + # Assert that output is of type GeneratorOutput + self.assertIsInstance(output, GeneratorOutput) + # Assert that output.data is a generator type + self.assertIsInstance(output.data, GeneratorType) + + received_chunks = [] + for chunk in output.data: + # Assert that each chunk is of type GeneratorOutput + self.assertIsInstance(chunk, GeneratorOutput) + received_chunks.append(chunk.raw_response) + + # Assert that the received chunks match the sent chunks + self.assertEqual(received_chunks, self.sent_chunks) + + def _mock_stream_generator( + self, completion: list[str] + ) -> GeneratorType[GeneratorOutput, None, None]: + """Simulates streamed API responses.""" + for chunk in completion: + yield GeneratorOutput(data=None, raw_response=chunk) + + if __name__ == "__main__": unittest.main()