Skip to content

Commit

Permalink
wip improved object retrieval (#10513)
Browse files Browse the repository at this point in the history
  • Loading branch information
logan-markewich authored Feb 16, 2024
1 parent 827feb1 commit 3546490
Show file tree
Hide file tree
Showing 7 changed files with 61 additions and 8 deletions.
44 changes: 43 additions & 1 deletion llama-index-core/llama_index/core/schema.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""Base schema for data structures."""

import json
import textwrap
import uuid
Expand Down Expand Up @@ -499,7 +500,26 @@ class IndexNode(TextNode):
"""

index_id: str
obj: Any = Field(exclude=True)
obj: Any = None

def dict(self, **kwargs: Any) -> Dict[str, Any]:
from llama_index.core.storage.docstore.utils import doc_to_json

data = super().dict(**kwargs)

try:
if self.obj is None:
data["obj"] = None
elif isinstance(self.obj, BaseNode):
data["obj"] = doc_to_json(self.obj)
elif isinstance(self.obj, BaseModel):
data["obj"] = self.obj.dict()
else:
data["obj"] = json.dumps(self.obj)
except Exception:
raise ValueError("IndexNode obj is not serializable: " + str(self.obj))

return data

@classmethod
def from_text_node(
Expand All @@ -514,6 +534,28 @@ def from_text_node(
index_id=index_id,
)

# TODO: return type here not supported by current mypy version
@classmethod
def from_dict(cls, data: Dict[str, Any], **kwargs: Any) -> Self: # type: ignore
output = super().from_dict(data, **kwargs)

obj = data.get("obj", None)
parsed_obj = None
if isinstance(obj, str):
parsed_obj = TextNode(text=obj)
elif isinstance(obj, dict):
from llama_index.core.storage.docstore.utils import json_to_doc

# check if its a node, else assume stringable
try:
parsed_obj = json_to_doc(obj)
except Exception:
parsed_obj = TextNode(text=str(obj))

output.obj = parsed_obj

return output

@classmethod
def get_type(cls) -> str:
return ObjectType.INDEX
Expand Down
6 changes: 3 additions & 3 deletions llama-index-core/llama_index/core/vector_stores/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,11 +71,11 @@ def metadata_dict_to_node(metadata: dict, text: Optional[str] = None) -> BaseNod

node: BaseNode
if node_type == IndexNode.class_name():
node = IndexNode.parse_raw(node_json)
node = IndexNode.from_json(node_json)
elif node_type == ImageNode.class_name():
node = ImageNode.parse_raw(node_json)
node = ImageNode.from_json(node_json)
else:
node = TextNode.parse_raw(node_json)
node = TextNode.from_json(node_json)

if text is not None:
node.set_content(text)
Expand Down
4 changes: 4 additions & 0 deletions llama-index-integrations/llms/llama-index-llms-litellm/BUILD
Original file line number Diff line number Diff line change
@@ -1,3 +1,7 @@
poetry_requirements(
name="poetry",
)

python_sources(
interpreter_constraints=["==3.9.*", "==3.10.*"],
)
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ readme = "README.md"
version = "0.1.1"

[tool.poetry.dependencies]
python = ">=3.8.1,<3.12"
python = ">=3.9,<3.12"
llama-index-core = "^0.10.1"
litellm = "^1.18.13"

Expand Down
Original file line number Diff line number Diff line change
@@ -1 +1,3 @@
python_tests()
python_tests(
interpreter_constraints=["==3.9.*", "==3.10.*"],
)
3 changes: 2 additions & 1 deletion llama-index-legacy/llama_index/legacy/core/base_retriever.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,7 @@ def _retrieve_from_object(
f"Retrieving from object {obj.__class__.__name__} with query {query_bundle.query_str}\n",
color="llama_pink",
)

if isinstance(obj, NodeWithScore):
return [obj]
elif isinstance(obj, BaseNode):
Expand Down Expand Up @@ -149,7 +150,7 @@ def _handle_recursive_retrieval(
node = n.node
score = n.score or 1.0
if isinstance(node, IndexNode):
obj = self.object_map.get(node.index_id, None)
obj = node.obj or self.object_map.get(node.index_id, None)
if obj is not None:
if self._verbose:
print_text(
Expand Down
6 changes: 5 additions & 1 deletion llama-index-legacy/llama_index/legacy/indices/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,11 @@ def __init__(
self._graph_store = self._storage_context.graph_store

objects = objects or []
self._object_map = {obj.index_id: obj.obj for obj in objects}
self._object_map = {}
for obj in objects:
self._object_map[obj.index_id] = obj.obj
obj.obj = None # clear the object avoid serialization issues

with self._service_context.callback_manager.as_trace("index_construction"):
if index_struct is None:
nodes = nodes or []
Expand Down

0 comments on commit 3546490

Please sign in to comment.