diff --git a/llama-index-integrations/storage/chat_store/llama-index-storage-chat-store-upstash/llama_index/storage/chat_store/upstash/base.py b/llama-index-integrations/storage/chat_store/llama-index-storage-chat-store-upstash/llama_index/storage/chat_store/upstash/base.py index 432df0c4f618f..4541c1987b04f 100644 --- a/llama-index-integrations/storage/chat_store/llama-index-storage-chat-store-upstash/llama_index/storage/chat_store/upstash/base.py +++ b/llama-index-integrations/storage/chat_store/llama-index-storage-chat-store-upstash/llama_index/storage/chat_store/upstash/base.py @@ -3,7 +3,9 @@ import logging from typing import List, Optional -from upstash_redis import Redis +from upstash_redis import Redis as SyncRedis +from upstash_redis.asyncio import Redis as AsyncRedis + from llama_index.core.bridge.pydantic import Field, PrivateAttr import json @@ -32,7 +34,9 @@ class UpstashChatStore(BaseChatStore): for managing chat messages in an Upstash Redis database. """ - _redis_client: Redis = PrivateAttr() + _sync_redis_client: SyncRedis = PrivateAttr() + _async_redis_client: AsyncRedis = PrivateAttr() + ttl: Optional[int] = Field(default=None, description="Time to live in seconds.") class Config: @@ -59,12 +63,13 @@ def __init__( raise ValueError("Please provide a valid URL and token") try: - self._redis_client = Redis(url=redis_url, token=redis_token) + self._sync_redis_client = SyncRedis(url=redis_url, token=redis_token) + self._async_redis_client = AsyncRedis(url=redis_url, token=redis_token) except Exception as error: logger.error(f"Upstash Redis client could not be initiated: {error}") # self.ttl = ttl - super().__init__(_redis_client=self._redis_client, ttl=ttl) + super().__init__(ttl=ttl) @classmethod def class_name(self) -> str: @@ -84,12 +89,27 @@ def set_messages(self, key: str, messages: List[ChatMessage]) -> None: key (str): The key to store the messages under. messages (List[ChatMessage]): The list of messages to store. """ - self._redis_client.delete(key) + self._sync_redis_client.delete(key) for message in messages: self.add_message(key, message) if self.ttl: - self._redis_client.expire(key, self.ttl) + self._sync_redis_client.expire(key, self.ttl) + + async def async_set_messages(self, key: str, messages: List[ChatMessage]) -> None: + """ + Set messages for a key. + + Args: + key (str): The key to store the messages under. + messages (List[ChatMessage]): The list of messages to store. + """ + await self._async_redis_client.delete(key) + for message in messages: + await self.async_add_message(key, message) + + if self.ttl: + await self._async_redis_client.expire(key, self.ttl) def get_messages(self, key: str) -> List[ChatMessage]: """ @@ -101,7 +121,23 @@ def get_messages(self, key: str) -> List[ChatMessage]: Returns: List[ChatMessage]: The list of retrieved messages. """ - items = self._redis_client.lrange(key, 0, -1) + items = self._sync_redis_client.lrange(key, 0, -1) + if len(items) == 0: + return [] + + return [ChatMessage.parse_raw(item) for item in items] + + async def async_get_messages(self, key: str) -> List[ChatMessage]: + """ + Get messages for a key. + + Args: + key (str): The key to retrieve messages from. + + Returns: + List[ChatMessage]: The list of retrieved messages. + """ + items = await self._async_redis_client.lrange(key, 0, -1) if len(items) == 0: return [] @@ -120,12 +156,32 @@ def add_message( """ if idx is None: message_json = json.dumps(_message_to_dict(message)) - self._redis_client.rpush(key, message_json) + self._sync_redis_client.rpush(key, message_json) else: self._insert_element_at_index(key, message, idx) if self.ttl: - self._redis_client.expire(key, self.ttl) + self._sync_redis_client.expire(key, self.ttl) + + async def async_add_message( + self, key: str, message: ChatMessage, idx: Optional[int] = None + ) -> None: + """ + Add a message to a key. + + Args: + key (str): The key to add the message to. + message (ChatMessage): The message to add. + idx (Optional[int]): The index at which to insert the message. + """ + if idx is None: + message_json = json.dumps(_message_to_dict(message)) + await self._async_redis_client.rpush(key, message_json) + else: + await self._async_insert_element_at_index(key, message, idx) + + if self.ttl: + await self._async_redis_client.expire(key, self.ttl) def delete_messages(self, key: str) -> Optional[List[ChatMessage]]: """ @@ -137,7 +193,20 @@ def delete_messages(self, key: str) -> Optional[List[ChatMessage]]: Returns: Optional[List[ChatMessage]]: Always returns None in this implementation. """ - self._redis_client.delete(key) + self._sync_redis_client.delete(key) + return None + + async def async_delete_messages(self, key: str) -> Optional[List[ChatMessage]]: + """ + Delete messages for a key. + + Args: + key (str): The key to delete messages from. + + Returns: + Optional[List[ChatMessage]]: Always returns None in this implementation. + """ + await self._async_redis_client.delete(key) return None def delete_message(self, key: str, idx: int) -> Optional[ChatMessage]: @@ -152,16 +221,45 @@ def delete_message(self, key: str, idx: int) -> Optional[ChatMessage]: Optional[ChatMessage]: The deleted message, or None if not found or an error occurred. """ try: - deleted_message = self._redis_client.lindex(key, idx) + deleted_message = self._sync_redis_client.lindex(key, idx) if deleted_message is None: return None placeholder = f"{key}:{idx}:deleted" - self._redis_client.lset(key, idx, placeholder) - self._redis_client.lrem(key, 1, placeholder) + self._sync_redis_client.lset(key, idx, placeholder) + self._sync_redis_client.lrem(key, 1, placeholder) if self.ttl: - self._redis_client.expire(key, self.ttl) + self._sync_redis_client.expire(key, self.ttl) + + return deleted_message + + except Exception as e: + logger.error(f"Error deleting message at index {idx} from {key}: {e}") + return None + + async def async_delete_message(self, key: str, idx: int) -> Optional[ChatMessage]: + """ + Delete a message from a key. + + Args: + key (str): The key to delete the message from. + idx (int): The index of the message to delete. + + Returns: + Optional[ChatMessage]: The deleted message, or None if not found or an error occurred. + """ + try: + deleted_message = await self._async_redis_client.lindex(key, idx) + if deleted_message is None: + return None + + placeholder = f"{key}:{idx}:deleted" + + await self._async_redis_client.lset(key, idx, placeholder) + await self._async_redis_client.lrem(key, 1, placeholder) + if self.ttl: + await self._async_redis_client.expire(key, self.ttl) return deleted_message @@ -179,7 +277,23 @@ def delete_last_message(self, key: str) -> Optional[ChatMessage]: Returns: Optional[ChatMessage]: The deleted message, or None if the list is empty. """ - return self._redis_client.rpop(key) + deleted_message = self._sync_redis_client.rpop(key) + return ChatMessage.parse_raw(deleted_message) if deleted_message else None + + async def async_delete_last_message(self, key: str) -> Optional[ChatMessage]: + """ + Delete the last message from a key. + + Args: + key (str): The key to delete the last message from. + + Returns: + Optional[ChatMessage]: The deleted message, or None if the list is empty. + """ + deleted_message = await self._async_redis_client.rpop(key) + if deleted_message: + return ChatMessage.parse_raw(deleted_message) + return None def get_keys(self) -> List[str]: """ @@ -188,7 +302,18 @@ def get_keys(self) -> List[str]: Returns: List[str]: A list of all keys in the Redis store. """ - return [key.decode("utf-8") for key in self._redis_client.keys("*")] + keys = self._sync_redis_client.keys("*") + return keys if isinstance(keys, list) else [keys] + + async def async_get_keys(self) -> List[str]: + """ + Get all keys. + + Returns: + List[str]: A list of all keys in the Redis store. + """ + keys = await self._async_redis_client.keys("*") + return keys if isinstance(keys, list) else [keys] def _insert_element_at_index( self, key: str, message: ChatMessage, idx: int @@ -207,8 +332,31 @@ def _insert_element_at_index( current_list = self.get_messages(key) current_list.insert(idx, message) - self._redis_client.delete(key) + self._sync_redis_client.delete(key) self.set_messages(key, current_list) return current_list + + async def _async_insert_element_at_index( + self, key: str, message: ChatMessage, idx: int + ) -> List[ChatMessage]: + """ + Insert a message at a specific index. + + Args: + key (str): The key of the list to insert into. + message (ChatMessage): The message to insert. + idx (int): The index at which to insert the message. + + Returns: + List[ChatMessage]: The updated list of messages. + """ + current_list = await self.async_get_messages(key) + current_list.insert(idx, message) + + await self.async_delete_messages(key) + + await self.async_set_messages(key, current_list) + + return current_list diff --git a/llama-index-integrations/storage/chat_store/llama-index-storage-chat-store-upstash/tests/test_chat_store_upstash_chat_store.py b/llama-index-integrations/storage/chat_store/llama-index-storage-chat-store-upstash/tests/test_chat_store_upstash_chat_store.py index 107d07da24d8f..3e653a0f1c079 100644 --- a/llama-index-integrations/storage/chat_store/llama-index-storage-chat-store-upstash/tests/test_chat_store_upstash_chat_store.py +++ b/llama-index-integrations/storage/chat_store/llama-index-storage-chat-store-upstash/tests/test_chat_store_upstash_chat_store.py @@ -4,13 +4,13 @@ from importlib.util import find_spec from llama_index.core.llms import ChatMessage import time - +import asyncio import logging -# Configure logging at the top of your test file logging.basicConfig( level=logging.DEBUG, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s" ) +logger = logging.getLogger(__name__) def dict_to_message(d: dict) -> ChatMessage: @@ -37,6 +37,9 @@ def upstash_chat_store() -> UpstashChatStore: ) +#################### +#### SYNC TESTS #### +#################### @pytest.mark.skip(reason="Skipping all tests") def test_invalid_initialization(): with pytest.raises(ValueError): @@ -138,3 +141,140 @@ def test_add_message_at_index(upstash_chat_store: UpstashChatStore): assert result_messages[0].content == "First message" assert result_messages[1].content == "Second message" assert result_messages[2].content == "Third message" + + +##################### +#### ASYNC TESTS #### +##################### + + +@pytest.mark.skip(reason="Skipping all tests") +@pytest.mark.asyncio() +async def test_async_upstash_add_message(upstash_chat_store: UpstashChatStore): + key = "test_async_add_key" + + message = ChatMessage(content="async_add_message_test", role="user") + await upstash_chat_store.async_add_message(key, message=message) + + result = await upstash_chat_store.async_get_messages(key) + + assert result[0].content == "async_add_message_test" and result[0].role == "user" + + +@pytest.mark.skip(reason="Skipping all tests") +@pytest.mark.asyncio() +async def test_async_set_and_retrieve_messages(upstash_chat_store: UpstashChatStore): + messages = [ + ChatMessage(content="First async message", role="user"), + ChatMessage(content="Second async message", role="user"), + ] + key = "test_async_set_key" + await upstash_chat_store.async_set_messages(key, messages) + + retrieved_messages = await upstash_chat_store.async_get_messages(key) + assert len(retrieved_messages) == 2 + assert retrieved_messages[0].content == "First async message" + assert retrieved_messages[1].content == "Second async message" + + +@pytest.mark.skip(reason="Skipping all tests") +@pytest.mark.asyncio() +async def test_async_delete_messages(upstash_chat_store: UpstashChatStore): + messages = [ChatMessage(content="Async message to delete", role="user")] + key = "test_async_delete_key" + await upstash_chat_store.async_set_messages(key, messages) + + await upstash_chat_store.async_delete_messages(key) + retrieved_messages = await upstash_chat_store.async_get_messages(key) + assert retrieved_messages == [] + + +@pytest.mark.skip(reason="Skipping all tests") +@pytest.mark.asyncio() +async def test_async_delete_specific_message(upstash_chat_store: UpstashChatStore): + messages = [ + ChatMessage(content="Async keep me", role="user"), + ChatMessage(content="Async delete me", role="user"), + ] + key = "test_async_delete_message_key" + await upstash_chat_store.async_set_messages(key, messages) + + await upstash_chat_store.async_delete_message(key, 1) + retrieved_messages = await upstash_chat_store.async_get_messages(key) + assert len(retrieved_messages) == 1 + assert retrieved_messages[0].content == "Async keep me" + + +@pytest.mark.skip(reason="Skipping all tests") +@pytest.mark.asyncio() +async def test_async_ttl_on_messages(upstash_chat_store: UpstashChatStore): + upstash_chat_store.ttl = 3 + key = "async_ttl_test_key" + message = ChatMessage(content="This async message will expire", role="user") + await upstash_chat_store.async_add_message(key, message) + + await asyncio.sleep(4) # Waiting for the ttl to expire. + + retrieved_messages = await upstash_chat_store.async_get_messages(key) + assert retrieved_messages == [] + + +@pytest.mark.skip(reason="Skipping all tests") +@pytest.mark.asyncio() +async def test_async_add_message_at_index(upstash_chat_store: UpstashChatStore): + key = "test_async_add_message_index_key" + # Clear any existing data for the key + await upstash_chat_store.async_delete_messages(key) + + initial_messages = [ + ChatMessage(content="First async message", role="user"), + ChatMessage(content="Third async message", role="user"), + ] + + await upstash_chat_store.async_set_messages(key, initial_messages) + + new_message = ChatMessage(content="Second async message", role="user") + await upstash_chat_store.async_add_message(key, new_message, idx=1) + + # Retrieve messages to check the order + result_messages = await upstash_chat_store.async_get_messages(key) + assert len(result_messages) == 3 + assert result_messages[0].content == "First async message" + assert result_messages[1].content == "Second async message" + assert result_messages[2].content == "Third async message" + + +@pytest.mark.skip(reason="Skipping all tests") +@pytest.mark.asyncio() +async def test_async_get_keys(upstash_chat_store: UpstashChatStore): + # Add some test data + await upstash_chat_store.async_set_messages( + "async_key1", [ChatMessage(content="Test1", role="user")] + ) + await upstash_chat_store.async_set_messages( + "async_key2", [ChatMessage(content="Test2", role="user")] + ) + + keys = await upstash_chat_store.async_get_keys() + assert "async_key1" in keys + assert "async_key2" in keys + + +@pytest.mark.skip(reason="Skipping all tests") +@pytest.mark.asyncio() +async def test_async_delete_last_message(upstash_chat_store: UpstashChatStore): + key = "test_async_delete_last_message" + messages = [ + ChatMessage(content="First async message", role="user"), + ChatMessage(content="Last async message", role="user"), + ] + await upstash_chat_store.async_set_messages(key, messages) + + deleted_message = await upstash_chat_store.async_delete_last_message(key) + + assert deleted_message.content == "Last async message" + + remaining_messages = await upstash_chat_store.async_get_messages(key) + + assert len(remaining_messages) == 1 + assert remaining_messages[0].content == "First async message"