From b8dd2719bf3edd90160a01b04151de11a2ed0e8c Mon Sep 17 00:00:00 2001 From: Logan Date: Fri, 16 Feb 2024 17:16:54 -0600 Subject: [PATCH] wip improved object retrieval (#10513) --- llama-index-core/llama_index/core/schema.py | 44 ++++++++++++++++++- .../llama_index/core/vector_stores/utils.py | 6 +-- .../llms/llama-index-llms-litellm/BUILD | 4 ++ .../llama-index-llms-litellm/pyproject.toml | 2 +- .../llms/llama-index-llms-litellm/tests/BUILD | 4 +- .../llama_index/legacy/core/base_retriever.py | 3 +- .../llama_index/legacy/indices/base.py | 6 ++- 7 files changed, 61 insertions(+), 8 deletions(-) diff --git a/llama-index-core/llama_index/core/schema.py b/llama-index-core/llama_index/core/schema.py index 9eac9bef74da1..4ecda5b9475b9 100644 --- a/llama-index-core/llama_index/core/schema.py +++ b/llama-index-core/llama_index/core/schema.py @@ -1,4 +1,5 @@ """Base schema for data structures.""" + import json import textwrap import uuid @@ -499,7 +500,26 @@ class IndexNode(TextNode): """ index_id: str - obj: Any = Field(exclude=True) + obj: Any = None + + def dict(self, **kwargs: Any) -> Dict[str, Any]: + from llama_index.core.storage.docstore.utils import doc_to_json + + data = super().dict(**kwargs) + + try: + if self.obj is None: + data["obj"] = None + elif isinstance(self.obj, BaseNode): + data["obj"] = doc_to_json(self.obj) + elif isinstance(self.obj, BaseModel): + data["obj"] = self.obj.dict() + else: + data["obj"] = json.dumps(self.obj) + except Exception: + raise ValueError("IndexNode obj is not serializable: " + str(self.obj)) + + return data @classmethod def from_text_node( @@ -514,6 +534,28 @@ def from_text_node( index_id=index_id, ) + # TODO: return type here not supported by current mypy version + @classmethod + def from_dict(cls, data: Dict[str, Any], **kwargs: Any) -> Self: # type: ignore + output = super().from_dict(data, **kwargs) + + obj = data.get("obj", None) + parsed_obj = None + if isinstance(obj, str): + parsed_obj = TextNode(text=obj) + elif 
isinstance(obj, dict): + from llama_index.core.storage.docstore.utils import json_to_doc + + # check if it's a node, else assume stringable + try: + parsed_obj = json_to_doc(obj) + except Exception: + parsed_obj = TextNode(text=str(obj)) + + output.obj = parsed_obj + + return output + @classmethod def get_type(cls) -> str: return ObjectType.INDEX diff --git a/llama-index-core/llama_index/core/vector_stores/utils.py b/llama-index-core/llama_index/core/vector_stores/utils.py index 13f04531c4fd2..66c42570be464 100644 --- a/llama-index-core/llama_index/core/vector_stores/utils.py +++ b/llama-index-core/llama_index/core/vector_stores/utils.py @@ -71,11 +71,11 @@ def metadata_dict_to_node(metadata: dict, text: Optional[str] = None) -> BaseNod node: BaseNode if node_type == IndexNode.class_name(): - node = IndexNode.parse_raw(node_json) + node = IndexNode.from_json(node_json) elif node_type == ImageNode.class_name(): - node = ImageNode.parse_raw(node_json) + node = ImageNode.from_json(node_json) else: - node = TextNode.parse_raw(node_json) + node = TextNode.from_json(node_json) if text is not None: node.set_content(text) diff --git a/llama-index-integrations/llms/llama-index-llms-litellm/BUILD b/llama-index-integrations/llms/llama-index-llms-litellm/BUILD index 0896ca890d8bf..a8f4940ed6efe 100644 --- a/llama-index-integrations/llms/llama-index-llms-litellm/BUILD +++ b/llama-index-integrations/llms/llama-index-llms-litellm/BUILD @@ -1,3 +1,7 @@ poetry_requirements( name="poetry", ) + +python_sources( + interpreter_constraints=["==3.9.*", "==3.10.*"], +) diff --git a/llama-index-integrations/llms/llama-index-llms-litellm/pyproject.toml b/llama-index-integrations/llms/llama-index-llms-litellm/pyproject.toml index 8165327c55ffb..fec49c9e66f13 100644 --- a/llama-index-integrations/llms/llama-index-llms-litellm/pyproject.toml +++ b/llama-index-integrations/llms/llama-index-llms-litellm/pyproject.toml @@ -27,7 +27,7 @@ readme = "README.md" version = "0.1.1" 
[tool.poetry.dependencies] -python = ">=3.8.1,<3.12" +python = ">=3.9,<3.12" llama-index-core = "^0.10.1" litellm = "^1.18.13" diff --git a/llama-index-integrations/llms/llama-index-llms-litellm/tests/BUILD b/llama-index-integrations/llms/llama-index-llms-litellm/tests/BUILD index dabf212d7e716..5cd7615688ba0 100644 --- a/llama-index-integrations/llms/llama-index-llms-litellm/tests/BUILD +++ b/llama-index-integrations/llms/llama-index-llms-litellm/tests/BUILD @@ -1 +1,3 @@ -python_tests() +python_tests( + interpreter_constraints=["==3.9.*", "==3.10.*"], +) diff --git a/llama-index-legacy/llama_index/legacy/core/base_retriever.py b/llama-index-legacy/llama_index/legacy/core/base_retriever.py index 2b69aab840036..9cdfd8b6ec82f 100644 --- a/llama-index-legacy/llama_index/legacy/core/base_retriever.py +++ b/llama-index-legacy/llama_index/legacy/core/base_retriever.py @@ -77,6 +77,7 @@ def _retrieve_from_object( f"Retrieving from object {obj.__class__.__name__} with query {query_bundle.query_str}\n", color="llama_pink", ) + if isinstance(obj, NodeWithScore): return [obj] elif isinstance(obj, BaseNode): @@ -149,7 +150,7 @@ def _handle_recursive_retrieval( node = n.node score = n.score or 1.0 if isinstance(node, IndexNode): - obj = self.object_map.get(node.index_id, None) + obj = node.obj or self.object_map.get(node.index_id, None) if obj is not None: if self._verbose: print_text( diff --git a/llama-index-legacy/llama_index/legacy/indices/base.py b/llama-index-legacy/llama_index/legacy/indices/base.py index 3482ec35d9575..416b5f1881881 100644 --- a/llama-index-legacy/llama_index/legacy/indices/base.py +++ b/llama-index-legacy/llama_index/legacy/indices/base.py @@ -67,7 +67,11 @@ def __init__( self._graph_store = self._storage_context.graph_store objects = objects or [] - self._object_map = {obj.index_id: obj.obj for obj in objects} + self._object_map = {} + for obj in objects: + self._object_map[obj.index_id] = obj.obj + obj.obj = None # clear the object to avoid 
serialization issues + with self._service_context.callback_manager.as_trace("index_construction"): if index_struct is None: nodes = nodes or []