diff --git a/thenewboston/discord/bot.py b/thenewboston/discord/bot.py index 41e64ce..479a297 100644 --- a/thenewboston/discord/bot.py +++ b/thenewboston/discord/bot.py @@ -13,6 +13,9 @@ from thenewboston.general.clients.llm import LLMClient, make_prompt_kwargs # noqa: E402 +USER_ROLE = 'user' +ASSISTANT_ROLE = 'assistant' + logger = logging.getLogger(__name__) intents = discord.Intents.default() @@ -32,7 +35,7 @@ def map_author_plaintext(author): def map_author_structured(author): - return 'assistant' if is_ia(author) else 'user' + return ASSISTANT_ROLE if is_ia(author) else USER_ROLE def messages_to_plaintext(messages): @@ -56,6 +59,9 @@ def messages_to_structured(messages): prev_role = role + if structured_messages and structured_messages[0]['role'] != USER_ROLE: + structured_messages.pop(0) + return structured_messages diff --git a/thenewboston/discord/tests/test_bot.py b/thenewboston/discord/tests/test_bot.py index 33a0ca7..a6f87c0 100644 --- a/thenewboston/discord/tests/test_bot.py +++ b/thenewboston/discord/tests/test_bot.py @@ -16,38 +16,38 @@ async def test_on_ready(): await on_ready() -@override_settings(IA_DISCORD_USER_ID=1234) +@override_settings(IA_DISCORD_USER_ID=1) def test_messages_to_structured(): - assert messages_to_structured([Message(author=Author(id=1234), content='hello')]) == [{ - 'role': 'assistant', + assert messages_to_structured([Message(author=Author(id=2), content='hello')]) == [{ + 'role': 'user', 'content': [{ 'type': 'text', 'text': 'hello' }] }] assert messages_to_structured([ - Message(author=Author(id=1234), content='hello'), - Message(author=Author(id=1234), content='world') + Message(author=Author(id=2), content='hello'), + Message(author=Author(id=2), content='world') ]) == [{ - 'role': 'assistant', + 'role': 'user', 'content': [{ 'type': 'text', 'text': 'hello\nworld' }] }] assert messages_to_structured([ - Message(author=Author(id=1234), content='hello'), - Message(author=Author(id=10), content='world') + Message(author=Author(id=2), content='hello'), + Message(author=Author(id=1), content='world') ]) == [ { - 'role': 'assistant', + 'role': 'user', 'content': [{ 'type': 'text', 'text': 'hello' }] }, { - 'role': 'user', + 'role': 'assistant', 'content': [{ 'type': 'text', 'text': 'world' @@ -55,26 +55,26 @@ def test_messages_to_structured(): }, ] assert messages_to_structured([ - Message(author=Author(id=1234), content='hello'), - Message(author=Author(id=10), content='world'), - Message(author=Author(id=1234), content='bye') + Message(author=Author(id=2), content='hello'), + Message(author=Author(id=1), content='world'), + Message(author=Author(id=2), content='bye') ]) == [ { - 'role': 'assistant', + 'role': 'user', 'content': [{ 'type': 'text', 'text': 'hello' }] }, { - 'role': 'user', + 'role': 'assistant', 'content': [{ 'type': 'text', 'text': 'world' }] }, { - 'role': 'assistant', + 'role': 'user', 'content': [{ 'type': 'text', 'text': 'bye' @@ -82,30 +82,41 @@ def test_messages_to_structured(): }, ] assert messages_to_structured([ - Message(author=Author(id=1234), content='hello'), - Message(author=Author(id=10), content='world'), - Message(author=Author(id=10), content='mine'), - Message(author=Author(id=1234), content='bye') + Message(author=Author(id=2), content='hello'), + Message(author=Author(id=1), content='world'), + Message(author=Author(id=2), content='mine'), + Message(author=Author(id=2), content='bye') ]) == [ { - 'role': 'assistant', + 'role': 'user', 'content': [{ 'type': 'text', 'text': 'hello' }] }, { - 'role': 'user', + 'role': 'assistant', 'content': [{ 'type': 'text', - 'text': 'world\nmine' + 'text': 'world' }] }, { - 'role': 'assistant', + 'role': 'user', 'content': [{ 'type': 'text', - 'text': 'bye' + 'text': 'mine\nbye' }] }, ] + + assert messages_to_structured([ + Message(author=Author(id=1), content='hello'), + Message(author=Author(id=2), content='world') + ]) == [{ + 'role': 'user', + 'content': [{ + 'type': 'text', + 'text': 'world' + }] + }]