Skip to content

Commit 016910d

Browse files
committed
remove comments and update CLA
1 parent 8f46a9a commit 016910d

File tree

2 files changed

+360
-48
lines changed

2 files changed

+360
-48
lines changed

src/google/adk/memory/chroma_memory_service.py

Lines changed: 119 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
# limitations under the License.
1414

1515
import chromadb
16+
import json
1617
from typing_extensions import override
1718

1819
from ..events.event import Event
@@ -22,18 +23,63 @@
2223
from .base_memory_service import SearchMemoryResponse
2324
from google.genai import types
2425
from ..events.event_actions import EventActions
26+
from chromadb.config import Settings
27+
2528

2629
class ChromaMemoryService(BaseMemoryService):
2730
"""A memory service that uses Chroma for storage and retrieval."""
28-
def __init__(self, top_k: int = 10):
31+
32+
def __init__(self, top_k: int = 10, is_persistent: bool = False):
2933
"""Initializes a ChromaMemoryService.
3034
Args:
3135
collection_name: The name of the Chroma collection to use.
3236
"""
33-
self.client = chromadb.EphemeralClient()
37+
settings = Settings(
38+
allow_reset=True
39+
)
40+
self.client = chromadb.EphemeralClient(settings=settings) if not is_persistent else chromadb.PersistentClient(settings=settings)
3441
self.top_k = top_k
42+
43+
def _clean_metadata_value(self, value):
    """Coerce ``value`` into a form suitable for a metadata dictionary entry.

    Args:
        value: Any event field value destined for the metadata dictionary.

    Returns:
        ``""`` when ``value`` is ``None``; ``value`` unchanged when it is a
        ``str``, ``int``, ``float`` or ``bool``; otherwise its JSON encoding,
        or ``str(value)`` when it cannot be JSON-encoded.
    """
    if value is None:
        # None is replaced by an empty-string placeholder.
        return ""
    if isinstance(value, (str, int, float, bool)):
        return value
    try:
        return json.dumps(value)
    except (TypeError, ValueError):
        # json.dumps raises TypeError for unsupported types and ValueError
        # for circular references; fall back to the plain string form.
        # Narrowed from a bare ``except`` so that KeyboardInterrupt and
        # SystemExit are no longer swallowed here.
        return str(value)
53+
54+
@staticmethod
def _parse_metadata_field_value(value, type_caster=None, is_json_load=False):
    """Helper to parse metadata values for Event reconstruction.

    Args:
        value: Raw metadata value. ``None`` and the empty string are both
            treated as "no value".
        type_caster: Optional callable applied to ``value`` (for example
            ``int`` or ``float``). Ignored when ``is_json_load`` is true.
        is_json_load: When true, a ``str`` value is decoded with
            ``json.loads``; a non-string value is returned unchanged.

    Returns:
        ``None`` for a missing value or a value the caster rejects;
        otherwise the decoded, cast, or unchanged value.

    Raises:
        ValueError: If ``is_json_load`` is true and ``value`` is a string
            that is not valid JSON.
    """
    if value is None or value == "":
        return None

    if is_json_load:
        if not isinstance(value, str):
            return value
        try:
            return json.loads(value)
        except json.JSONDecodeError as e:
            # Chain the original error so the failing position stays visible.
            raise ValueError(f"Failed to parse JSON string: {e}") from e

    if not type_caster:
        return value

    try:
        return type_caster(value)
    except (ValueError, TypeError):
        # A value of the wrong type (e.g. a list passed to int) is treated
        # the same as an unparsable string: the field is reported as absent.
        return None
80+
3581
@override
36-
def add_session_to_memory(self, session: Session):
82+
async def add_session_to_memory(self, session: Session):
3783
"""Adds a session to the memory service.
3884
3985
Args:
@@ -48,7 +94,7 @@ def add_session_to_memory(self, session: Session):
4894
text_parts = []
4995
for part in event.content.parts:
5096
if part.text:
51-
text_parts.append(part.text)
97+
text_parts.append(part.text.lower())
5298

5399
if not text_parts:
54100
continue
@@ -59,26 +105,26 @@ def add_session_to_memory(self, session: Session):
59105
ids=[f"{session.id}_{event.id}"],
60106
documents=[text],
61107
metadatas=[{
62-
"session_id": session.id,
63-
"event_id": event.id,
64-
"invocation_id": event.invocation_id,
65-
"author": event.author,
66-
"timestamp": event.timestamp,
67-
"branch": event.branch or "",
68-
"actions": event.actions.model_dump_json() if event.actions else None,
69-
"long_running_tool_ids": str(list(event.long_running_tool_ids)) if event.long_running_tool_ids else None,
70-
"grounding_metadata": event.grounding_metadata.model_dump_json() if event.grounding_metadata else None,
71-
"partial": event.partial,
72-
"turn_complete": event.turn_complete,
73-
"error_code": event.error_code,
74-
"error_message": event.error_message,
75-
"interrupted": event.interrupted,
76-
"custom_metadata": event.custom_metadata,
108+
"session_id": self._clean_metadata_value(session.id),
109+
"event_id": self._clean_metadata_value(event.id),
110+
"invocation_id": self._clean_metadata_value(event.invocation_id),
111+
"author": self._clean_metadata_value(event.author),
112+
"timestamp": self._clean_metadata_value(event.timestamp),
113+
"branch": self._clean_metadata_value(event.branch or ""),
114+
"actions": self._clean_metadata_value(event.actions.model_dump_json() if event.actions else None),
115+
"long_running_tool_ids": self._clean_metadata_value(str(list(event.long_running_tool_ids)) if event.long_running_tool_ids else None),
116+
"grounding_metadata": self._clean_metadata_value(event.grounding_metadata.model_dump_json() if event.grounding_metadata else None),
117+
"partial": self._clean_metadata_value(event.partial),
118+
"turn_complete": self._clean_metadata_value(event.turn_complete),
119+
"error_code": self._clean_metadata_value(event.error_code),
120+
"error_message": self._clean_metadata_value(event.error_message),
121+
"interrupted": self._clean_metadata_value(event.interrupted),
122+
"custom_metadata": self._clean_metadata_value(event.custom_metadata),
77123
}],
78124
)
79125

80126
@override
81-
def search_memory(
127+
async def search_memory(
82128
self, *, app_name: str, user_id: str, query: str
83129
) -> SearchMemoryResponse:
84130
"""Searches for sessions that match the query using both semantic and keyword search.
@@ -96,40 +142,67 @@ def search_memory(
96142
except Exception as e:
97143
return SearchMemoryResponse(memories=[])
98144

99-
# Perform keyword search
145+
# Perform hybrid search
100146
keywords = set(query.lower().split())
147+
if len(keywords) == 0:
148+
where_document = {}
149+
elif len(keywords) == 1:
150+
keyword_val = keywords.pop()
151+
where_document = {"$contains": keyword_val}
152+
else:
153+
where_document = {"$or": [{"$contains": keyword} for keyword in keywords]}
154+
101155
results = collection.query(
102-
query_texts=[query],
156+
query_texts=[query] if query else None,
103157
n_results=self.top_k,
104-
where_document={"$or": [
105-
{"$contains": keyword}
106-
for keyword in keywords
107-
]}
158+
where_document=where_document if where_document else None,
159+
include=['metadatas', 'documents']
108160
)
109161

110162
session_events = {}
111-
for i, doc_id in enumerate(results["ids"][0]):
112-
session_id = doc_id.split("_")[0]
113-
event_id = doc_id.split("_")[1]
163+
if not (results and results.get("ids") and results["ids"] and \
164+
isinstance(results["ids"][0], list) and results["ids"][0]):
165+
return SearchMemoryResponse(memories=[])
166+
167+
doc_ids = results["ids"][0]
168+
all_metadatas = results.get("metadatas", [[]])[0]
169+
all_documents = results.get("documents", [[]])[0]
170+
171+
if not (len(doc_ids) == len(all_metadatas) == len(all_documents)):
172+
return SearchMemoryResponse(memories=[])
173+
174+
for i, doc_id_str in enumerate(doc_ids):
175+
if not isinstance(doc_id_str, str):
176+
continue
177+
178+
try:
179+
id_parts = doc_id_str.rsplit('_', 1)
180+
if len(id_parts) == 2:
181+
session_id, event_id = id_parts[0], id_parts[1]
182+
else:
183+
continue
184+
except Exception as e:
185+
continue
186+
187+
current_metadata = all_metadatas[i]
188+
current_document_text = all_documents[i]
114189

115-
116-
metadata = results["metadatas"][0][i]
117190
event = Event(
118-
id=event_id,
119-
invocation_id=metadata["invocation_id"],
120-
author=metadata["author"],
121-
timestamp=metadata["timestamp"],
122-
content=types.Content(parts=[types.Part(text=results["documents"][0][i])]),
123-
branch=metadata["branch"] if metadata["branch"] else None,
124-
actions=EventActions.model_validate_json(metadata["actions"]) if metadata["actions"] else None,
125-
long_running_tool_ids=set(eval(metadata["long_running_tool_ids"])) if metadata["long_running_tool_ids"] else None,
126-
grounding_metadata=types.GroundingMetadata.model_validate_json(metadata["grounding_metadata"]) if metadata["grounding_metadata"] else None,
127-
partial=metadata["partial"],
128-
turn_complete=metadata["turn_complete"],
129-
error_code=metadata["error_code"],
130-
error_message=metadata["error_message"],
131-
interrupted=metadata["interrupted"],
132-
custom_metadata=metadata["custom_metadata"]
191+
id=event_id, # Use the correctly parsed event_id
192+
invocation_id=current_metadata.get("invocation_id"),
193+
author=current_metadata.get("author"),
194+
timestamp=ChromaMemoryService._parse_metadata_field_value(current_metadata.get("timestamp"), type_caster=float),
195+
content=types.Content(parts=[types.Part(text=current_document_text)]),
196+
branch=current_metadata.get("branch") if current_metadata.get("branch") else None,
197+
actions=EventActions.model_validate_json(current_metadata["actions"]) if current_metadata.get("actions") else None,
198+
long_running_tool_ids=set(eval(current_metadata["long_running_tool_ids"])) if current_metadata.get("long_running_tool_ids") else None,
199+
grounding_metadata=types.GroundingMetadata.model_validate_json(current_metadata["grounding_metadata"]) if current_metadata.get("grounding_metadata") else None,
200+
partial=ChromaMemoryService._parse_metadata_field_value(current_metadata.get("partial")),
201+
turn_complete=ChromaMemoryService._parse_metadata_field_value(current_metadata.get("turn_complete")),
202+
error_code=ChromaMemoryService._parse_metadata_field_value(current_metadata.get("error_code"), type_caster=int),
203+
error_message=current_metadata.get("error_message"),
204+
interrupted=ChromaMemoryService._parse_metadata_field_value(current_metadata.get("interrupted")),
205+
custom_metadata=ChromaMemoryService._parse_metadata_field_value(current_metadata.get("custom_metadata"), is_json_load=True)
133206
)
134207

135208
if session_id not in session_events:
@@ -146,5 +219,3 @@ def search_memory(
146219
)
147220

148221
return SearchMemoryResponse(memories=memory_results)
149-
150-

0 commit comments

Comments
 (0)