wip improved object retrieval #10513

Merged (7 commits) on Feb 16, 2024
44 changes: 43 additions & 1 deletion llama-index-core/llama_index/core/schema.py
@@ -1,4 +1,5 @@
"""Base schema for data structures."""

import json
import textwrap
import uuid
@@ -499,7 +500,26 @@ class IndexNode(TextNode):
"""

index_id: str
obj: Any = Field(exclude=True)
obj: Any = None

def dict(self, **kwargs: Any) -> Dict[str, Any]:
from llama_index.core.storage.docstore.utils import doc_to_json

data = super().dict(**kwargs)

try:
if self.obj is None:
data["obj"] = None
elif isinstance(self.obj, BaseNode):
data["obj"] = doc_to_json(self.obj)
elif isinstance(self.obj, BaseModel):
data["obj"] = self.obj.dict()
else:
data["obj"] = json.dumps(self.obj)
except Exception:
raise ValueError("IndexNode obj is not serializable: " + str(self.obj))

return data

@classmethod
def from_text_node(
@@ -514,6 +534,28 @@ def from_text_node(
             index_id=index_id,
         )

+    # TODO: return type here not supported by current mypy version
+    @classmethod
+    def from_dict(cls, data: Dict[str, Any], **kwargs: Any) -> Self:  # type: ignore
+        output = super().from_dict(data, **kwargs)
+
+        obj = data.get("obj", None)
+        parsed_obj = None
+        if isinstance(obj, str):
+            parsed_obj = TextNode(text=obj)
+        elif isinstance(obj, dict):
+            from llama_index.core.storage.docstore.utils import json_to_doc
+
+            # check if its a node, else assume stringable
+            try:
+                parsed_obj = json_to_doc(obj)
+            except Exception:
+                parsed_obj = TextNode(text=str(obj))
+
+        output.obj = parsed_obj
+
+        return output
+
     @classmethod
     def get_type(cls) -> str:
         return ObjectType.INDEX
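
For readers skimming the diff, here is a minimal sketch (not part of the PR) of how the new obj handling is meant to round-trip; the field values and node ids are made up, and the import path is an assumption based on the llama_index.core layout shown above:

    from llama_index.core.schema import IndexNode, TextNode

    # Wrap another node in an IndexNode; obj used to be excluded from serialization.
    inner = TextNode(text="payload", id_="inner-node")
    index_node = IndexNode(text="pointer", index_id="my-index", obj=inner)

    # dict() now serializes obj (a BaseNode goes through doc_to_json).
    data = index_node.dict()

    # from_dict() rebuilds obj: a dict payload is parsed back into a node,
    # and a plain string would be wrapped in a TextNode.
    restored = IndexNode.from_dict(data)
    assert restored.obj is not None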
6 changes: 3 additions & 3 deletions llama-index-core/llama_index/core/vector_stores/utils.py
@@ -71,11 +71,11 @@ def metadata_dict_to_node(metadata: dict, text: Optional[str] = None) -> BaseNode:

     node: BaseNode
     if node_type == IndexNode.class_name():
-        node = IndexNode.parse_raw(node_json)
+        node = IndexNode.from_json(node_json)
     elif node_type == ImageNode.class_name():
-        node = ImageNode.parse_raw(node_json)
+        node = ImageNode.from_json(node_json)
     else:
-        node = TextNode.parse_raw(node_json)
+        node = TextNode.from_json(node_json)

     if text is not None:
         node.set_content(text)
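
Worth noting (my reading, not stated in the diff): from_json() goes through the from_dict() hook added in schema.py above, so an IndexNode rebuilt from vector-store metadata gets its obj field reconstructed, which pydantic's parse_raw() would skip. A rough sketch of that behavior, with made-up values:

    import json

    from llama_index.core.schema import IndexNode

    # Serialize an IndexNode whose obj is a plain dict; dict() stores it as a JSON string.
    node_json = json.dumps(IndexNode(text="ptr", index_id="idx", obj={"k": "v"}).dict())

    # from_json -> from_dict: the stored obj comes back as a node (here a TextNode
    # wrapping the JSON string) rather than being left behind.
    node = IndexNode.from_json(node_json)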
4 changes: 4 additions & 0 deletions llama-index-integrations/llms/llama-index-llms-litellm/BUILD
@@ -1,3 +1,7 @@
 poetry_requirements(
     name="poetry",
 )
+
+python_sources(
+    interpreter_constraints=["==3.9.*", "==3.10.*"],
+)
@@ -27,7 +27,7 @@ readme = "README.md"
version = "0.1.1"

[tool.poetry.dependencies]
python = ">=3.8.1,<3.12"
python = ">=3.9,<3.12"
llama-index-core = "^0.10.1"
litellm = "^1.18.13"

@@ -1 +1,3 @@
-python_tests()
+python_tests(
+    interpreter_constraints=["==3.9.*", "==3.10.*"],
+)
3 changes: 2 additions & 1 deletion llama-index-legacy/llama_index/legacy/core/base_retriever.py
@@ -77,6 +77,7 @@ def _retrieve_from_object(
f"Retrieving from object {obj.__class__.__name__} with query {query_bundle.query_str}\n",
color="llama_pink",
)

if isinstance(obj, NodeWithScore):
return [obj]
elif isinstance(obj, BaseNode):
@@ -149,7 +150,7 @@ def _handle_recursive_retrieval(
             node = n.node
             score = n.score or 1.0
             if isinstance(node, IndexNode):
-                obj = self.object_map.get(node.index_id, None)
+                obj = node.obj or self.object_map.get(node.index_id, None)
                 if obj is not None:
                     if self._verbose:
                         print_text(
6 changes: 5 additions & 1 deletion llama-index-legacy/llama_index/legacy/indices/base.py
@@ -67,7 +67,11 @@ def __init__(
         self._graph_store = self._storage_context.graph_store

         objects = objects or []
-        self._object_map = {obj.index_id: obj.obj for obj in objects}
+        self._object_map = {}
+        for obj in objects:
+            self._object_map[obj.index_id] = obj.obj
+            obj.obj = None  # clear the object to avoid serialization issues
+
         with self._service_context.callback_manager.as_trace("index_construction"):
             if index_struct is None:
                 nodes = nodes or []
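
As a sketch of my understanding (not code from the PR): the loop replaces the old dict comprehension so that each IndexNode's obj is lifted into the index's object map and then cleared, keeping the node itself serializable when the index is persisted; at query time the base_retriever.py change above falls back from node.obj to this map. Roughly, with generic names:

    # objects: a list of IndexNode instances passed to the index constructor
    object_map = {}
    for index_node in objects:
        object_map[index_node.index_id] = index_node.obj
        index_node.obj = None  # the node no longer drags the raw object into storage

    # later, during recursive retrieval (mirrors the retriever change above):
    obj = node.obj or object_map.get(node.index_id, None)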