13
13
# limitations under the License.
14
14
15
15
import chromadb
16
+ import json
16
17
from typing_extensions import override
17
18
18
19
from ..events .event import Event
22
23
from .base_memory_service import SearchMemoryResponse
23
24
from google .genai import types
24
25
from ..events .event_actions import EventActions
26
+ from chromadb .config import Settings
27
+
25
28
26
29
class ChromaMemoryService (BaseMemoryService ):
27
30
"""A memory service that uses Chroma for storage and retrieval."""
28
- def __init__ (self , top_k : int = 10 ):
31
+
32
def __init__(self, top_k: int = 10, is_persistent: bool = False):
    """Build the Chroma client used by this memory service.

    Args:
      top_k: Kept on the instance as ``self.top_k``.
      is_persistent: Selects the client class. A true value creates a
        ``chromadb.PersistentClient``; otherwise a
        ``chromadb.EphemeralClient`` is created.
    """
    # Both client flavours receive the same settings object.
    client_settings = Settings(allow_reset=True)

    if is_persistent:
        make_client = chromadb.PersistentClient
    else:
        make_client = chromadb.EphemeralClient

    self.client = make_client(settings=client_settings)
    self.top_k = top_k
42
+
43
def _clean_metadata_value(self, value):
    """Reduce ``value`` to a ``str``, ``int``, ``float`` or ``bool``.

    Args:
      value: Any object destined for a metadata field.

    Returns:
      ``""`` when ``value`` is ``None``; ``value`` itself when it already is
      a ``str``, ``int``, ``float`` or ``bool``; otherwise its JSON encoding,
      falling back to ``str(value)`` when it cannot be JSON-encoded.
    """
    if value is None:
        return ""
    if isinstance(value, (str, int, float, bool)):
        return value
    try:
        return json.dumps(value)
    except (TypeError, ValueError, RecursionError):
        # Only serialisation failures trigger the fallback; a bare ``except``
        # here would also swallow KeyboardInterrupt and SystemExit.
        return str(value)
53
+
54
@staticmethod
def _parse_metadata_field_value(value, type_caster=None, is_json_load=False):
    """Convert a stored metadata value back into a Python value.

    Args:
      value: The raw metadata value. ``None`` and ``""`` both mean "absent".
      type_caster: Optional callable applied to ``value`` (for example
        ``float`` or ``int``). Ignored when ``is_json_load`` is true.
      is_json_load: When true, a ``str`` value is decoded with ``json.loads``
        and any other value is returned unchanged.

    Returns:
      ``None`` for an absent value or one that ``type_caster`` cannot
      convert; otherwise the decoded, cast, or original value.

    Raises:
      ValueError: If ``is_json_load`` is true and ``value`` is a string that
        is not valid JSON.
    """
    if value is None or value == "":
        return None

    if is_json_load:
        if not isinstance(value, str):
            return value
        try:
            return json.loads(value)
        except json.JSONDecodeError as e:
            # Chain the decoder error so the offending position stays visible.
            raise ValueError(f"Failed to parse JSON string: {e}") from e

    if type_caster:
        try:
            return type_caster(value)
        except (TypeError, ValueError):
            # Casters such as float/int raise TypeError for unsupported input
            # types (e.g. a list); treat that like any other uncastable value.
            return None

    return value
80
+
35
81
@override
36
- def add_session_to_memory (self , session : Session ):
82
+ async def add_session_to_memory (self , session : Session ):
37
83
"""Adds a session to the memory service.
38
84
39
85
Args:
@@ -48,7 +94,7 @@ def add_session_to_memory(self, session: Session):
48
94
text_parts = []
49
95
for part in event .content .parts :
50
96
if part .text :
51
- text_parts .append (part .text )
97
+ text_parts .append (part .text . lower () )
52
98
53
99
if not text_parts :
54
100
continue
@@ -59,26 +105,26 @@ def add_session_to_memory(self, session: Session):
59
105
ids = [f"{ session .id } _{ event .id } " ],
60
106
documents = [text ],
61
107
metadatas = [{
62
- "session_id" : session .id ,
63
- "event_id" : event .id ,
64
- "invocation_id" : event .invocation_id ,
65
- "author" : event .author ,
66
- "timestamp" : event .timestamp ,
67
- "branch" : event .branch or "" ,
68
- "actions" : event .actions .model_dump_json () if event .actions else None ,
69
- "long_running_tool_ids" : str (list (event .long_running_tool_ids )) if event .long_running_tool_ids else None ,
70
- "grounding_metadata" : event .grounding_metadata .model_dump_json () if event .grounding_metadata else None ,
71
- "partial" : event .partial ,
72
- "turn_complete" : event .turn_complete ,
73
- "error_code" : event .error_code ,
74
- "error_message" : event .error_message ,
75
- "interrupted" : event .interrupted ,
76
- "custom_metadata" : event .custom_metadata ,
108
+ "session_id" : self . _clean_metadata_value ( session .id ) ,
109
+ "event_id" : self . _clean_metadata_value ( event .id ) ,
110
+ "invocation_id" : self . _clean_metadata_value ( event .invocation_id ) ,
111
+ "author" : self . _clean_metadata_value ( event .author ) ,
112
+ "timestamp" : self . _clean_metadata_value ( event .timestamp ) ,
113
+ "branch" : self . _clean_metadata_value ( event .branch or "" ) ,
114
+ "actions" : self . _clean_metadata_value ( event .actions .model_dump_json () if event .actions else None ) ,
115
+ "long_running_tool_ids" : self . _clean_metadata_value ( str (list (event .long_running_tool_ids )) if event .long_running_tool_ids else None ) ,
116
+ "grounding_metadata" : self . _clean_metadata_value ( event .grounding_metadata .model_dump_json () if event .grounding_metadata else None ) ,
117
+ "partial" : self . _clean_metadata_value ( event .partial ) ,
118
+ "turn_complete" : self . _clean_metadata_value ( event .turn_complete ) ,
119
+ "error_code" : self . _clean_metadata_value ( event .error_code ) ,
120
+ "error_message" : self . _clean_metadata_value ( event .error_message ) ,
121
+ "interrupted" : self . _clean_metadata_value ( event .interrupted ) ,
122
+ "custom_metadata" : self . _clean_metadata_value ( event .custom_metadata ) ,
77
123
}],
78
124
)
79
125
80
126
@override
81
- def search_memory (
127
+ async def search_memory (
82
128
self , * , app_name : str , user_id : str , query : str
83
129
) -> SearchMemoryResponse :
84
130
"""Searches for sessions that match the query using both semantic and keyword search.
@@ -96,40 +142,67 @@ def search_memory(
96
142
except Exception as e :
97
143
return SearchMemoryResponse (memories = [])
98
144
99
- # Perform keyword search
145
+ # Perform hybrid search
100
146
keywords = set (query .lower ().split ())
147
+ if len (keywords ) == 0 :
148
+ where_document = {}
149
+ elif len (keywords ) == 1 :
150
+ keyword_val = keywords .pop ()
151
+ where_document = {"$contains" : keyword_val }
152
+ else :
153
+ where_document = {"$or" : [{"$contains" : keyword } for keyword in keywords ]}
154
+
101
155
results = collection .query (
102
- query_texts = [query ],
156
+ query_texts = [query ] if query else None ,
103
157
n_results = self .top_k ,
104
- where_document = {"$or" : [
105
- {"$contains" : keyword }
106
- for keyword in keywords
107
- ]}
158
+ where_document = where_document if where_document else None ,
159
+ include = ['metadatas' , 'documents' ]
108
160
)
109
161
110
162
session_events = {}
111
- for i , doc_id in enumerate (results ["ids" ][0 ]):
112
- session_id = doc_id .split ("_" )[0 ]
113
- event_id = doc_id .split ("_" )[1 ]
163
+ if not (results and results .get ("ids" ) and results ["ids" ] and \
164
+ isinstance (results ["ids" ][0 ], list ) and results ["ids" ][0 ]):
165
+ return SearchMemoryResponse (memories = [])
166
+
167
+ doc_ids = results ["ids" ][0 ]
168
+ all_metadatas = results .get ("metadatas" , [[]])[0 ]
169
+ all_documents = results .get ("documents" , [[]])[0 ]
170
+
171
+ if not (len (doc_ids ) == len (all_metadatas ) == len (all_documents )):
172
+ return SearchMemoryResponse (memories = [])
173
+
174
+ for i , doc_id_str in enumerate (doc_ids ):
175
+ if not isinstance (doc_id_str , str ):
176
+ continue
177
+
178
+ try :
179
+ id_parts = doc_id_str .rsplit ('_' , 1 )
180
+ if len (id_parts ) == 2 :
181
+ session_id , event_id = id_parts [0 ], id_parts [1 ]
182
+ else :
183
+ continue
184
+ except Exception as e :
185
+ continue
186
+
187
+ current_metadata = all_metadatas [i ]
188
+ current_document_text = all_documents [i ]
114
189
115
-
116
- metadata = results ["metadatas" ][0 ][i ]
117
190
event = Event (
118
- id = event_id ,
119
- invocation_id = metadata [ "invocation_id" ] ,
120
- author = metadata [ "author" ] ,
121
- timestamp = metadata [ "timestamp" ] ,
122
- content = types .Content (parts = [types .Part (text = results [ "documents" ][ 0 ][ i ] )]),
123
- branch = metadata [ "branch" ] if metadata [ "branch" ] else None ,
124
- actions = EventActions .model_validate_json (metadata ["actions" ]) if metadata [ "actions" ] else None ,
125
- long_running_tool_ids = set (eval (metadata ["long_running_tool_ids" ])) if metadata [ "long_running_tool_ids" ] else None ,
126
- grounding_metadata = types .GroundingMetadata .model_validate_json (metadata ["grounding_metadata" ]) if metadata [ "grounding_metadata" ] else None ,
127
- partial = metadata [ "partial" ] ,
128
- turn_complete = metadata [ "turn_complete" ] ,
129
- error_code = metadata [ "error_code" ] ,
130
- error_message = metadata [ "error_message" ] ,
131
- interrupted = metadata [ "interrupted" ] ,
132
- custom_metadata = metadata [ "custom_metadata" ]
191
+ id = event_id , # Use the correctly parsed event_id
192
+ invocation_id = current_metadata . get ( "invocation_id" ) ,
193
+ author = current_metadata . get ( "author" ) ,
194
+ timestamp = ChromaMemoryService . _parse_metadata_field_value ( current_metadata . get ( "timestamp" ), type_caster = float ) ,
195
+ content = types .Content (parts = [types .Part (text = current_document_text )]),
196
+ branch = current_metadata . get ( "branch" ) if current_metadata . get ( "branch" ) else None ,
197
+ actions = EventActions .model_validate_json (current_metadata ["actions" ]) if current_metadata . get ( "actions" ) else None ,
198
+ long_running_tool_ids = set (eval (current_metadata ["long_running_tool_ids" ])) if current_metadata . get ( "long_running_tool_ids" ) else None ,
199
+ grounding_metadata = types .GroundingMetadata .model_validate_json (current_metadata ["grounding_metadata" ]) if current_metadata . get ( "grounding_metadata" ) else None ,
200
+ partial = ChromaMemoryService . _parse_metadata_field_value ( current_metadata . get ( "partial" )) ,
201
+ turn_complete = ChromaMemoryService . _parse_metadata_field_value ( current_metadata . get ( "turn_complete" )) ,
202
+ error_code = ChromaMemoryService . _parse_metadata_field_value ( current_metadata . get ( "error_code" ), type_caster = int ) ,
203
+ error_message = current_metadata . get ( "error_message" ) ,
204
+ interrupted = ChromaMemoryService . _parse_metadata_field_value ( current_metadata . get ( "interrupted" )) ,
205
+ custom_metadata = ChromaMemoryService . _parse_metadata_field_value ( current_metadata . get ( "custom_metadata" ), is_json_load = True )
133
206
)
134
207
135
208
if session_id not in session_events :
@@ -146,5 +219,3 @@ def search_memory(
146
219
)
147
220
148
221
return SearchMemoryResponse (memories = memory_results )
149
-
150
-
0 commit comments