From 4f26ba27cc5370bc6294f72a2c84a73a803d714e Mon Sep 17 00:00:00 2001 From: fahreddinozcan Date: Mon, 30 Sep 2024 13:59:18 +0300 Subject: [PATCH] fix: make redis client private attr --- .../storage/chat_store/upstash/base.py | 39 ++++++++++--------- 1 file changed, 21 insertions(+), 18 deletions(-) diff --git a/llama-index-integrations/storage/chat_store/llama-index-storage-chat-store-upstash/llama_index/storage/chat_store/upstash/base.py b/llama-index-integrations/storage/chat_store/llama-index-storage-chat-store-upstash/llama_index/storage/chat_store/upstash/base.py index ab82f5449b224..432df0c4f618f 100644 --- a/llama-index-integrations/storage/chat_store/llama-index-storage-chat-store-upstash/llama_index/storage/chat_store/upstash/base.py +++ b/llama-index-integrations/storage/chat_store/llama-index-storage-chat-store-upstash/llama_index/storage/chat_store/upstash/base.py @@ -2,9 +2,9 @@ from llama_index.core.storage.chat_store.base import BaseChatStore import logging -from typing import List, Optional, Any +from typing import List, Optional from upstash_redis import Redis -from llama_index.core.bridge.pydantic import Field +from llama_index.core.bridge.pydantic import Field, PrivateAttr import json logger = logging.getLogger(__name__) @@ -32,9 +32,12 @@ class UpstashChatStore(BaseChatStore): for managing chat messages in an Upstash Redis database. """ - redis_client: Any = Field(description="Redis client.") + _redis_client: Redis = PrivateAttr() ttl: Optional[int] = Field(default=None, description="Time to live in seconds.") + class Config: + arbitrary_types_allowed = True + def __init__( self, redis_url: str = "", @@ -56,12 +59,12 @@ def __init__( raise ValueError("Please provide a valid URL and token") try: - self.redis_client = Redis(url=redis_url, token=redis_token) + self._redis_client = Redis(url=redis_url, token=redis_token) except Exception as error: logger.error(f"Upstash Redis client could not be initiated: {error}") # self.ttl = ttl - super().__init__(redis_client=self.redis_client, ttl=ttl) + super().__init__(_redis_client=self._redis_client, ttl=ttl) @classmethod def class_name(self) -> str: @@ -81,12 +84,12 @@ def set_messages(self, key: str, messages: List[ChatMessage]) -> None: key (str): The key to store the messages under. messages (List[ChatMessage]): The list of messages to store. """ - self.redis_client.delete(key) + self._redis_client.delete(key) for message in messages: self.add_message(key, message) if self.ttl: - self.redis_client.expire(key, self.ttl) + self._redis_client.expire(key, self.ttl) def get_messages(self, key: str) -> List[ChatMessage]: """ @@ -98,7 +101,7 @@ def get_messages(self, key: str) -> List[ChatMessage]: Returns: List[ChatMessage]: The list of retrieved messages. """ - items = self.redis_client.lrange(key, 0, -1) + items = self._redis_client.lrange(key, 0, -1) if len(items) == 0: return [] @@ -117,12 +120,12 @@ def add_message( """ if idx is None: message_json = json.dumps(_message_to_dict(message)) - self.redis_client.rpush(key, message_json) + self._redis_client.rpush(key, message_json) else: self._insert_element_at_index(key, message, idx) if self.ttl: - self.redis_client.expire(key, self.ttl) + self._redis_client.expire(key, self.ttl) def delete_messages(self, key: str) -> Optional[List[ChatMessage]]: """ @@ -134,7 +137,7 @@ def delete_messages(self, key: str) -> Optional[List[ChatMessage]]: Returns: Optional[List[ChatMessage]]: Always returns None in this implementation. """ - self.redis_client.delete(key) + self._redis_client.delete(key) return None def delete_message(self, key: str, idx: int) -> Optional[ChatMessage]: @@ -149,16 +152,16 @@ def delete_message(self, key: str, idx: int) -> Optional[ChatMessage]: Optional[ChatMessage]: The deleted message, or None if not found or an error occurred. """ try: - deleted_message = self.redis_client.lindex(key, idx) + deleted_message = self._redis_client.lindex(key, idx) if deleted_message is None: return None placeholder = f"{key}:{idx}:deleted" - self.redis_client.lset(key, idx, placeholder) - self.redis_client.lrem(key, 1, placeholder) + self._redis_client.lset(key, idx, placeholder) + self._redis_client.lrem(key, 1, placeholder) if self.ttl: - self.redis_client.expire(key, self.ttl) + self._redis_client.expire(key, self.ttl) return deleted_message @@ -176,7 +179,7 @@ def delete_last_message(self, key: str) -> Optional[ChatMessage]: Returns: Optional[ChatMessage]: The deleted message, or None if the list is empty. """ - return self.redis_client.rpop(key) + return self._redis_client.rpop(key) def get_keys(self) -> List[str]: """ @@ -185,7 +188,7 @@ def get_keys(self) -> List[str]: Returns: List[str]: A list of all keys in the Redis store. """ - return [key.decode("utf-8") for key in self.redis_client.keys("*")] + return [key.decode("utf-8") for key in self._redis_client.keys("*")] def _insert_element_at_index( self, key: str, message: ChatMessage, idx: int @@ -204,7 +207,7 @@ def _insert_element_at_index( current_list = self.get_messages(key) current_list.insert(idx, message) - self.redis_client.delete(key) + self._redis_client.delete(key) self.set_messages(key, current_list)