diff --git a/gptcli/gpt.py b/gptcli/gpt.py index 28e5111..19710cf 100755 --- a/gptcli/gpt.py +++ b/gptcli/gpt.py @@ -240,8 +240,7 @@ def __init__(self, assistant: Assistant, markdown: bool, show_price: bool, strea listeners.append(PriceChatListener(assistant)) listener = CompositeChatListener(listeners) - self.stream = stream - super().__init__(assistant, listener) + super().__init__(assistant, listener, stream) def run_interactive(args, assistant): diff --git a/gptcli/session.py b/gptcli/session.py index 581146c..e107c55 100644 --- a/gptcli/session.py +++ b/gptcli/session.py @@ -81,11 +81,13 @@ def __init__( self, assistant: Assistant, listener: ChatListener, + stream: bool = True, ): self.assistant = assistant self.messages: List[Message] = assistant.init_messages() self.user_prompts: List[Tuple[Message, ModelOverrides]] = [] self.listener = listener + self.stream = stream def _clear(self): self.messages = self.assistant.init_messages() diff --git a/tests/test_session.py b/tests/test_session.py index c2e2496..a72b98d 100644 --- a/tests/test_session.py +++ b/tests/test_session.py @@ -40,7 +40,7 @@ def test_simple_input(): assistant_message = {"role": "assistant", "content": expected_response} assistant_mock.complete_chat.assert_called_once_with( - [system_message, user_message], override_params={} + [system_message, user_message], override_params={}, stream=True, ) listener_mock.on_chat_message.assert_has_calls( [mock.call(user_message), mock.call(assistant_message)] @@ -66,7 +66,7 @@ def test_clear(): assistant_mock.complete_chat.assert_called_once_with( [system_message, {"role": "user", "content": "user_message"}], - override_params={}, + override_params={}, stream=True, ) listener_mock.on_chat_message.assert_has_calls( [ @@ -93,7 +93,7 @@ def test_clear(): assistant_mock.complete_chat.assert_called_once_with( [system_message, {"role": "user", "content": "user_message_1"}], - override_params={}, + override_params={}, stream=True, ) listener_mock.on_chat_message.assert_has_calls( [ @@ -128,7 +128,7 @@ def test_rerun(): assistant_mock.complete_chat.assert_called_once_with( [system_message, {"role": "user", "content": "user_message"}], - override_params={}, + override_params={}, stream=True, ) listener_mock.on_chat_message.assert_has_calls( [ @@ -150,7 +150,7 @@ def test_rerun(): assistant_mock.complete_chat.assert_called_once_with( [system_message, {"role": "user", "content": "user_message"}], - override_params={}, + override_params={}, stream=True, ) listener_mock.on_chat_message.assert_has_calls( [ @@ -175,7 +175,7 @@ def test_args(): assistant_message = {"role": "assistant", "content": expected_response} assistant_mock.complete_chat.assert_called_once_with( - [system_message, user_message], override_params={"arg1": "value1"} + [system_message, user_message], override_params={"arg1": "value1"}, stream=True, ) listener_mock.on_chat_message.assert_has_calls( [mock.call(user_message), mock.call(assistant_message)] @@ -191,7 +191,7 @@ def test_args(): assert should_continue assistant_mock.complete_chat.assert_called_once_with( - [system_message, user_message], override_params={"arg1": "value1"} + [system_message, user_message], override_params={"arg1": "value1"}, stream=True, ) listener_mock.on_chat_message.assert_has_calls([mock.call(assistant_message)]) @@ -250,7 +250,7 @@ def test_openai_error(): assert should_continue assistant_mock.complete_chat.assert_called_once_with( - [system_message, user_message], override_params={} + [system_message, user_message], override_params={}, stream=True, ) listener_mock.on_chat_message.assert_has_calls( [