Skip to content

Commit

Permalink
feat: async support
Browse files Browse the repository at this point in the history
  • Loading branch information
fahreddinozcan committed Sep 30, 2024
1 parent 4f26ba2 commit 9d1caba
Show file tree
Hide file tree
Showing 2 changed files with 307 additions and 19 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,9 @@

import logging
from typing import List, Optional
from upstash_redis import Redis
from upstash_redis import Redis as SyncRedis
from upstash_redis.asyncio import Redis as AsyncRedis

from llama_index.core.bridge.pydantic import Field, PrivateAttr
import json

Expand Down Expand Up @@ -32,7 +34,9 @@ class UpstashChatStore(BaseChatStore):
for managing chat messages in an Upstash Redis database.
"""

_redis_client: Redis = PrivateAttr()
_sync_redis_client: SyncRedis = PrivateAttr()
_async_redis_client: AsyncRedis = PrivateAttr()

ttl: Optional[int] = Field(default=None, description="Time to live in seconds.")

class Config:
Expand All @@ -59,12 +63,13 @@ def __init__(
raise ValueError("Please provide a valid URL and token")

try:
self._redis_client = Redis(url=redis_url, token=redis_token)
self._sync_redis_client = SyncRedis(url=redis_url, token=redis_token)
self._async_redis_client = AsyncRedis(url=redis_url, token=redis_token)
except Exception as error:
logger.error(f"Upstash Redis client could not be initiated: {error}")

# self.ttl = ttl
super().__init__(_redis_client=self._redis_client, ttl=ttl)
super().__init__(ttl=ttl)

@classmethod
def class_name(self) -> str:
Expand All @@ -84,12 +89,27 @@ def set_messages(self, key: str, messages: List[ChatMessage]) -> None:
key (str): The key to store the messages under.
messages (List[ChatMessage]): The list of messages to store.
"""
self._redis_client.delete(key)
self._sync_redis_client.delete(key)
for message in messages:
self.add_message(key, message)

if self.ttl:
self._redis_client.expire(key, self.ttl)
self._sync_redis_client.expire(key, self.ttl)

async def async_set_messages(self, key: str, messages: List[ChatMessage]) -> None:
"""
Set messages for a key.
Args:
key (str): The key to store the messages under.
messages (List[ChatMessage]): The list of messages to store.
"""
await self._async_redis_client.delete(key)
for message in messages:
await self.async_add_message(key, message)

if self.ttl:
await self._async_redis_client.expire(key, self.ttl)

def get_messages(self, key: str) -> List[ChatMessage]:
"""
Expand All @@ -101,7 +121,23 @@ def get_messages(self, key: str) -> List[ChatMessage]:
Returns:
List[ChatMessage]: The list of retrieved messages.
"""
items = self._redis_client.lrange(key, 0, -1)
items = self._sync_redis_client.lrange(key, 0, -1)
if len(items) == 0:
return []

return [ChatMessage.parse_raw(item) for item in items]

async def async_get_messages(self, key: str) -> List[ChatMessage]:
"""
Get messages for a key.
Args:
key (str): The key to retrieve messages from.
Returns:
List[ChatMessage]: The list of retrieved messages.
"""
items = await self._async_redis_client.lrange(key, 0, -1)
if len(items) == 0:
return []

Expand All @@ -120,12 +156,32 @@ def add_message(
"""
if idx is None:
message_json = json.dumps(_message_to_dict(message))
self._redis_client.rpush(key, message_json)
self._sync_redis_client.rpush(key, message_json)
else:
self._insert_element_at_index(key, message, idx)

if self.ttl:
self._redis_client.expire(key, self.ttl)
self._sync_redis_client.expire(key, self.ttl)

async def async_add_message(
self, key: str, message: ChatMessage, idx: Optional[int] = None
) -> None:
"""
Add a message to a key.
Args:
key (str): The key to add the message to.
message (ChatMessage): The message to add.
idx (Optional[int]): The index at which to insert the message.
"""
if idx is None:
message_json = json.dumps(_message_to_dict(message))
await self._async_redis_client.rpush(key, message_json)
else:
await self._async_insert_element_at_index(key, message, idx)

if self.ttl:
await self._async_redis_client.expire(key, self.ttl)

def delete_messages(self, key: str) -> Optional[List[ChatMessage]]:
"""
Expand All @@ -137,7 +193,20 @@ def delete_messages(self, key: str) -> Optional[List[ChatMessage]]:
Returns:
Optional[List[ChatMessage]]: Always returns None in this implementation.
"""
self._redis_client.delete(key)
self._sync_redis_client.delete(key)
return None

async def async_delete_messages(self, key: str) -> Optional[List[ChatMessage]]:
"""
Delete messages for a key.
Args:
key (str): The key to delete messages from.
Returns:
Optional[List[ChatMessage]]: Always returns None in this implementation.
"""
await self._async_redis_client.delete(key)
return None

def delete_message(self, key: str, idx: int) -> Optional[ChatMessage]:
Expand All @@ -152,16 +221,45 @@ def delete_message(self, key: str, idx: int) -> Optional[ChatMessage]:
Optional[ChatMessage]: The deleted message, or None if not found or an error occurred.
"""
try:
deleted_message = self._redis_client.lindex(key, idx)
deleted_message = self._sync_redis_client.lindex(key, idx)
if deleted_message is None:
return None

placeholder = f"{key}:{idx}:deleted"

self._redis_client.lset(key, idx, placeholder)
self._redis_client.lrem(key, 1, placeholder)
self._sync_redis_client.lset(key, idx, placeholder)
self._sync_redis_client.lrem(key, 1, placeholder)
if self.ttl:
self._redis_client.expire(key, self.ttl)
self._sync_redis_client.expire(key, self.ttl)

return deleted_message

except Exception as e:
logger.error(f"Error deleting message at index {idx} from {key}: {e}")
return None

async def async_delete_message(self, key: str, idx: int) -> Optional[ChatMessage]:
"""
Delete a message from a key.
Args:
key (str): The key to delete the message from.
idx (int): The index of the message to delete.
Returns:
Optional[ChatMessage]: The deleted message, or None if not found or an error occurred.
"""
try:
deleted_message = await self._async_redis_client.lindex(key, idx)
if deleted_message is None:
return None

placeholder = f"{key}:{idx}:deleted"

await self._async_redis_client.lset(key, idx, placeholder)
await self._async_redis_client.lrem(key, 1, placeholder)
if self.ttl:
await self._async_redis_client.expire(key, self.ttl)

return deleted_message

Expand All @@ -179,7 +277,23 @@ def delete_last_message(self, key: str) -> Optional[ChatMessage]:
Returns:
Optional[ChatMessage]: The deleted message, or None if the list is empty.
"""
return self._redis_client.rpop(key)
deleted_message = self._sync_redis_client.rpop(key)
return ChatMessage.parse_raw(deleted_message) if deleted_message else None

async def async_delete_last_message(self, key: str) -> Optional[ChatMessage]:
"""
Delete the last message from a key.
Args:
key (str): The key to delete the last message from.
Returns:
Optional[ChatMessage]: The deleted message, or None if the list is empty.
"""
deleted_message = await self._async_redis_client.rpop(key)
if deleted_message:
return ChatMessage.parse_raw(deleted_message)
return None

def get_keys(self) -> List[str]:
"""
Expand All @@ -188,7 +302,18 @@ def get_keys(self) -> List[str]:
Returns:
List[str]: A list of all keys in the Redis store.
"""
return [key.decode("utf-8") for key in self._redis_client.keys("*")]
keys = self._sync_redis_client.keys("*")
return keys if isinstance(keys, list) else [keys]

async def async_get_keys(self) -> List[str]:
"""
Get all keys.
Returns:
List[str]: A list of all keys in the Redis store.
"""
keys = await self._async_redis_client.keys("*")
return keys if isinstance(keys, list) else [keys]

def _insert_element_at_index(
self, key: str, message: ChatMessage, idx: int
Expand All @@ -207,8 +332,31 @@ def _insert_element_at_index(
current_list = self.get_messages(key)
current_list.insert(idx, message)

self._redis_client.delete(key)
self._sync_redis_client.delete(key)

self.set_messages(key, current_list)

return current_list

async def _async_insert_element_at_index(
self, key: str, message: ChatMessage, idx: int
) -> List[ChatMessage]:
"""
Insert a message at a specific index.
Args:
key (str): The key of the list to insert into.
message (ChatMessage): The message to insert.
idx (int): The index at which to insert the message.
Returns:
List[ChatMessage]: The updated list of messages.
"""
current_list = await self.async_get_messages(key)
current_list.insert(idx, message)

await self.async_delete_messages(key)

await self.async_set_messages(key, current_list)

return current_list
Loading

0 comments on commit 9d1caba

Please sign in to comment.