Skip to content

Commit

Permalink
fix: make redis client private attr
Browse files Browse the repository at this point in the history
  • Loading branch information
fahreddinozcan committed Sep 30, 2024
1 parent 0e856ec commit 4f26ba2
Showing 1 changed file with 21 additions and 18 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,9 @@
from llama_index.core.storage.chat_store.base import BaseChatStore

import logging
from typing import List, Optional, Any
from typing import List, Optional
from upstash_redis import Redis
from llama_index.core.bridge.pydantic import Field
from llama_index.core.bridge.pydantic import Field, PrivateAttr
import json

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -32,9 +32,12 @@ class UpstashChatStore(BaseChatStore):
for managing chat messages in an Upstash Redis database.
"""

redis_client: Any = Field(description="Redis client.")
_redis_client: Redis = PrivateAttr()
ttl: Optional[int] = Field(default=None, description="Time to live in seconds.")

class Config:
arbitrary_types_allowed = True

def __init__(
self,
redis_url: str = "",
Expand All @@ -56,12 +59,12 @@ def __init__(
raise ValueError("Please provide a valid URL and token")

try:
self.redis_client = Redis(url=redis_url, token=redis_token)
self._redis_client = Redis(url=redis_url, token=redis_token)
except Exception as error:
logger.error(f"Upstash Redis client could not be initiated: {error}")

# self.ttl = ttl
super().__init__(redis_client=self.redis_client, ttl=ttl)
super().__init__(_redis_client=self._redis_client, ttl=ttl)

@classmethod
def class_name(self) -> str:
Expand All @@ -81,12 +84,12 @@ def set_messages(self, key: str, messages: List[ChatMessage]) -> None:
key (str): The key to store the messages under.
messages (List[ChatMessage]): The list of messages to store.
"""
self.redis_client.delete(key)
self._redis_client.delete(key)
for message in messages:
self.add_message(key, message)

if self.ttl:
self.redis_client.expire(key, self.ttl)
self._redis_client.expire(key, self.ttl)

def get_messages(self, key: str) -> List[ChatMessage]:
"""
Expand All @@ -98,7 +101,7 @@ def get_messages(self, key: str) -> List[ChatMessage]:
Returns:
List[ChatMessage]: The list of retrieved messages.
"""
items = self.redis_client.lrange(key, 0, -1)
items = self._redis_client.lrange(key, 0, -1)
if len(items) == 0:
return []

Expand All @@ -117,12 +120,12 @@ def add_message(
"""
if idx is None:
message_json = json.dumps(_message_to_dict(message))
self.redis_client.rpush(key, message_json)
self._redis_client.rpush(key, message_json)
else:
self._insert_element_at_index(key, message, idx)

if self.ttl:
self.redis_client.expire(key, self.ttl)
self._redis_client.expire(key, self.ttl)

def delete_messages(self, key: str) -> Optional[List[ChatMessage]]:
"""
Expand All @@ -134,7 +137,7 @@ def delete_messages(self, key: str) -> Optional[List[ChatMessage]]:
Returns:
Optional[List[ChatMessage]]: Always returns None in this implementation.
"""
self.redis_client.delete(key)
self._redis_client.delete(key)
return None

def delete_message(self, key: str, idx: int) -> Optional[ChatMessage]:
Expand All @@ -149,16 +152,16 @@ def delete_message(self, key: str, idx: int) -> Optional[ChatMessage]:
Optional[ChatMessage]: The deleted message, or None if not found or an error occurred.
"""
try:
deleted_message = self.redis_client.lindex(key, idx)
deleted_message = self._redis_client.lindex(key, idx)
if deleted_message is None:
return None

placeholder = f"{key}:{idx}:deleted"

self.redis_client.lset(key, idx, placeholder)
self.redis_client.lrem(key, 1, placeholder)
self._redis_client.lset(key, idx, placeholder)
self._redis_client.lrem(key, 1, placeholder)
if self.ttl:
self.redis_client.expire(key, self.ttl)
self._redis_client.expire(key, self.ttl)

return deleted_message

Expand All @@ -176,7 +179,7 @@ def delete_last_message(self, key: str) -> Optional[ChatMessage]:
Returns:
Optional[ChatMessage]: The deleted message, or None if the list is empty.
"""
return self.redis_client.rpop(key)
return self._redis_client.rpop(key)

def get_keys(self) -> List[str]:
"""
Expand All @@ -185,7 +188,7 @@ def get_keys(self) -> List[str]:
Returns:
List[str]: A list of all keys in the Redis store.
"""
return [key.decode("utf-8") for key in self.redis_client.keys("*")]
return [key.decode("utf-8") for key in self._redis_client.keys("*")]

def _insert_element_at_index(
self, key: str, message: ChatMessage, idx: int
Expand All @@ -204,7 +207,7 @@ def _insert_element_at_index(
current_list = self.get_messages(key)
current_list.insert(idx, message)

self.redis_client.delete(key)
self._redis_client.delete(key)

self.set_messages(key, current_list)

Expand Down

0 comments on commit 4f26ba2

Please sign in to comment.