diff --git a/dbgpt/app/component_configs.py b/dbgpt/app/component_configs.py
index 0d6eabde8..fefe57a16 100644
--- a/dbgpt/app/component_configs.py
+++ b/dbgpt/app/component_configs.py
@@ -28,11 +28,6 @@ def initialize_components(
     system_app.register(DefaultExecutorFactory)
     system_app.register_instance(controller)
 
-    # Register global default RAGGraphFactory
-    # from dbgpt.graph.graph_factory import DefaultRAGGraphFactory
-
-    # system_app.register(DefaultRAGGraphFactory)
-
     from dbgpt.serve.agent.hub.controller import module_agent
 
     system_app.register_instance(module_agent)
diff --git a/dbgpt/app/knowledge/service.py b/dbgpt/app/knowledge/service.py
index 3db67c026..eba47df37 100644
--- a/dbgpt/app/knowledge/service.py
+++ b/dbgpt/app/knowledge/service.py
@@ -497,30 +497,6 @@ def get_document_chunks(self, request: ChunkQueryRequest):
         res.page = request.page
         return res
 
-    def async_knowledge_graph(self, chunk_docs, doc):
-        """Async extract triplets from document chunks and save them into the graph db.
-
-        Args:
-            - chunk_docs: List[Document]
-            - doc: KnowledgeDocumentEntity
-        """
-        logger.info(
-            f"async_knowledge_graph, doc:{doc.doc_name}, chunk_size:{len(chunk_docs)}, begin embedding to graph store"
-        )
-        try:
-            from dbgpt.rag.graph.graph_factory import RAGGraphFactory
-
-            rag_engine = CFG.SYSTEM_APP.get_component(
-                ComponentType.RAG_GRAPH_DEFAULT.value, RAGGraphFactory
-            ).create()
-            rag_engine.knowledge_graph(chunk_docs)
-            doc.status = SyncStatus.FINISHED.name
-            doc.result = "document build graph success"
-        except Exception as e:
-            doc.status = SyncStatus.FAILED.name
-            doc.result = "document build graph failed" + str(e)
-            logger.error(f"document build graph failed:{doc.doc_name}, {str(e)}")
-        return knowledge_document_dao.update_knowledge_document(doc)
-
     def async_doc_embedding(self, assembler, chunk_docs, doc):
         """async document embedding into vector db
         Args:
diff --git a/dbgpt/rag/graph/graph_engine.py b/dbgpt/rag/graph/graph_engine.py
deleted file mode 100644
index b00ca6695..000000000
--- a/dbgpt/rag/graph/graph_engine.py
+++ /dev/null
@@ -1,148 +0,0 @@
-import logging
-from typing import Any, Callable, List, Optional, Tuple
-
-from langchain.schema import Document
-from langchain.text_splitter import RecursiveCharacterTextSplitter
-
-from dbgpt.rag.embedding import KnowledgeType
-from dbgpt.rag.embedding.knowledge_type import get_knowledge_embedding
-from dbgpt.rag.graph.index_struct import KG
-from dbgpt.rag.graph.node import TextNode
-from dbgpt.util import utils
-
-logger = logging.getLogger(__name__)
-
-
-class RAGGraphEngine:
-    """Knowledge RAG Graph Engine.
-
-    Build a RAG graph client that can extract triplets and insert them into a graph store.
-
-    Args:
-        knowledge_type (Optional[str]): Knowledge type to extract triplets from.
-            Defaults to KnowledgeType.DOCUMENT.value.
-        knowledge_source (Optional[str]): The knowledge source.
-        model_name (Optional[str]): llm model name.
-        graph_store (Optional[GraphStore]): The graph store to use. Reference: llama-index.
-        include_embeddings (bool): Whether to include embeddings in the index.
-            Defaults to False.
-        max_object_length (int): The maximum length of the object in a triplet.
-            Defaults to 128.
-        extract_triplet_fn (Optional[Callable]): The function to use for
-            extracting triplets. Defaults to None.
-    """
-
-    index_struct_cls = KG
-
-    def __init__(
-        self,
-        knowledge_type: Optional[str] = KnowledgeType.DOCUMENT.value,
-        knowledge_source: Optional[str] = None,
-        text_splitter=None,
-        graph_store=None,
-        index_struct: Optional[KG] = None,
-        model_name: Optional[str] = None,
-        max_triplets_per_chunk: int = 10,
-        include_embeddings: bool = False,
-        max_object_length: int = 128,
-        extract_triplet_fn: Optional[Callable] = None,
-        **kwargs: Any,
-    ) -> None:
-        """Initialize params."""
-        from llama_index.graph_stores import SimpleGraphStore
-
-        # need to set parameters before building index in base class.
-        self.knowledge_source = knowledge_source
-        self.knowledge_type = knowledge_type
-        self.model_name = model_name
-        self.text_splitter = text_splitter
-        self.index_struct = index_struct
-        self.include_embeddings = include_embeddings
-        self.graph_store = graph_store or SimpleGraphStore()
-        # self.graph_store = graph_store
-        self.max_triplets_per_chunk = max_triplets_per_chunk
-        self._max_object_length = max_object_length
-        self._extract_triplet_fn = extract_triplet_fn
-
-    def knowledge_graph(self, docs=None):
-        """Build knowledge docs into the graph store."""
-        if not docs:
-            if self.text_splitter:
-                self.text_splitter = RecursiveCharacterTextSplitter(
-                    chunk_size=2000, chunk_overlap=100
-                )
-            knowledge_source = get_knowledge_embedding(
-                knowledge_type=self.knowledge_type,
-                knowledge_source=self.knowledge_source,
-                text_splitter=self.text_splitter,
-            )
-            docs = knowledge_source.read()
-        if self.index_struct is None:
-            self.index_struct = self._build_index_from_docs(docs)
-
-    def _extract_triplets(self, text: str) -> List[Tuple[str, str, str]]:
-        """Extract triplets from text by function or llm."""
-        if self._extract_triplet_fn is not None:
-            return self._extract_triplet_fn(text)
-        else:
-            return self._llm_extract_triplets(text)
-
-    def _llm_extract_triplets(self, text: str) -> List[Tuple[str, str, str]]:
-        """Extract triplets from text by llm."""
-        import uuid
-
-        from dbgpt.app.scene import ChatScene
-        from dbgpt.util.chat_util import llm_chat_response_nostream
-
-        chat_param = {
-            "chat_session_id": uuid.uuid1(),
-            "current_user_input": text,
-            "select_param": "triplet",
-            "model_name": self.model_name,
-        }
-        loop = utils.get_or_create_event_loop()
-        triplets = loop.run_until_complete(
-            llm_chat_response_nostream(
-                ChatScene.ExtractTriplet.value(), **{"chat_param": chat_param}
-            )
-        )
-        return triplets
-
-    def _build_index_from_docs(self, documents: List[Document]) -> KG:
-        """Build the index from nodes.
-
-        Args: documents: List[Document]
-        """
-        index_struct = self.index_struct_cls()
-        triplets = []
-        for doc in documents:
-            trips = self._extract_triplets_task([doc], index_struct)
-            triplets.extend(trips)
-        print(triplets)
-        text_node = TextNode(text=doc.page_content, metadata=doc.metadata)
-        for triplet in triplets:
-            subj, _, obj = triplet
-            self.graph_store.upsert_triplet(*triplet)
-            index_struct.add_node([subj, obj], text_node)
-        return index_struct
-
-    def search(self, query):
-        from dbgpt.rag.graph.graph_search import RAGGraphSearch
-
-        graph_search = RAGGraphSearch(graph_engine=self)
-        return graph_search.search(query)
-
-    def _extract_triplets_task(self, docs, index_struct):
-        triple_results = []
-        for doc in docs:
-            import threading
-
-            thread_id = threading.get_ident()
-            print(f"current thread-{thread_id} begin extract triplets task")
-            triplets = self._extract_triplets(doc.page_content)
-            if len(triplets) == 0:
-                triplets = []
-            text_node = TextNode(text=doc.page_content, metadata=doc.metadata)
-            logger.info(f"extracted knowledge triplets: {triplets}")
-            print(
-                f"current thread-{thread_id} end extract triplets tasks, triplets-{triplets}"
-            )
-            triple_results.extend(triplets)
-        return triple_results
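Note: for readers tracking what this removal drops, the deleted RAGGraphEngine boiled down to "extract (subject, predicate, object) triplets per chunk, upsert them into a graph store, and index the subject/object keywords back to the source node". The sketch below is a minimal, stdlib-only illustration of that flow; SimpleGraphStore, build_graph_index, and the hard-coded extractor are hypothetical stand-ins (the real engine called an LLM via ChatScene.ExtractTriplet), not DB-GPT APIs.

# triplet_index_sketch.py -- illustrative only, not the removed implementation
from collections import defaultdict
from typing import Callable, Dict, List, Set, Tuple

Triplet = Tuple[str, str, str]


class SimpleGraphStore:
    """Dict-backed stand-in for the llama-index SimpleGraphStore used above."""

    def __init__(self) -> None:
        self._rel_map: Dict[str, List[Tuple[str, str]]] = defaultdict(list)

    def upsert_triplet(self, subj: str, rel: str, obj: str) -> None:
        self._rel_map[subj.lower()].append((rel, obj))


def build_graph_index(
    chunks: List[str], extract_fn: Callable[[str], List[Triplet]]
) -> Tuple[SimpleGraphStore, Dict[str, Set[int]]]:
    """Mirror of _build_index_from_docs: store triplets, index keywords -> node."""
    store = SimpleGraphStore()
    keyword_table: Dict[str, Set[int]] = defaultdict(set)
    for node_id, text in enumerate(chunks):
        for subj, rel, obj in extract_fn(text):
            store.upsert_triplet(subj, rel, obj)
            # mirrors index_struct.add_node([subj, obj], text_node)
            keyword_table[subj.lower()].add(node_id)
            keyword_table[obj.lower()].add(node_id)
    return store, keyword_table


if __name__ == "__main__":
    # The removed engine delegated this step to an LLM; a hard-coded
    # extractor keeps the sketch runnable without one.
    docs = ["DB-GPT supports AWEL. AWEL simplifies LLM workflows."]
    extract = lambda _text: [
        ("DB-GPT", "supports", "AWEL"),
        ("AWEL", "simplifies", "LLM workflows"),
    ]
    store, table = build_graph_index(docs, extract)
    print(dict(store._rel_map), dict(table))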
diff --git a/dbgpt/rag/graph/graph_factory.py b/dbgpt/rag/graph/graph_factory.py
deleted file mode 100644
index ebdc23dd5..000000000
--- a/dbgpt/rag/graph/graph_factory.py
+++ /dev/null
@@ -1,35 +0,0 @@
-from __future__ import annotations
-
-from abc import ABC, abstractmethod
-from typing import Any, Type
-
-from dbgpt.component import BaseComponent, ComponentType
-
-
-class RAGGraphFactory(BaseComponent, ABC):
-    name = ComponentType.RAG_GRAPH_DEFAULT.value
-
-    @abstractmethod
-    def create(self, model_name: str = None, embedding_cls: Type = None):
-        """Create RAG Graph Engine."""
-
-
-class DefaultRAGGraphFactory(RAGGraphFactory):
-    def __init__(
-        self, system_app=None, default_model_name: str = None, **kwargs: Any
-    ) -> None:
-        super().__init__(system_app=system_app)
-        self._default_model_name = default_model_name
-        self.kwargs = kwargs
-        from dbgpt.rag.graph.graph_engine import RAGGraphEngine
-
-        self.rag_engine = RAGGraphEngine(model_name="proxyllm")
-
-    def init_app(self, system_app):
-        pass
-
-    def create(self, model_name: str = None, rag_cls: Type = None):
-        if not model_name:
-            model_name = self._default_model_name
-
-        return self.rag_engine
diff --git a/dbgpt/rag/graph/graph_search.py b/dbgpt/rag/graph/graph_search.py
deleted file mode 100644
index d485afc66..000000000
--- a/dbgpt/rag/graph/graph_search.py
+++ /dev/null
@@ -1,198 +0,0 @@
-import logging
-import os
-from collections import defaultdict
-from concurrent.futures import ThreadPoolExecutor
-from typing import Any, Callable, Dict, List, Optional, Set
-
-from langchain.schema import Document
-
-from dbgpt.rag.graph.node import BaseNode, NodeWithScore, TextNode
-from dbgpt.rag.graph.search import BaseSearch, SearchMode
-
-logger = logging.getLogger(__name__)
-DEFAULT_NODE_SCORE = 1000.0
-GLOBAL_EXPLORE_NODE_LIMIT = 3
-REL_TEXT_LIMIT = 30
-
-
-class RAGGraphSearch(BaseSearch):
-    """RAG Graph Search.
-
-    Args:
-        graph_engine (RAGGraphEngine): The graph engine to search.
-        model_name (str): llm model name.
-        text_qa_template (Optional[BasePromptTemplate]): A Question Answering Prompt
-            (see :ref:`Prompt-Templates`).
-        max_keywords_per_query (int): Maximum number of keywords to extract from query.
-        num_chunks_per_query (int): Maximum number of text chunks to query.
-        search_mode (Optional[SearchMode]): Specifies whether to use keywords,
-            embeddings, or both to find relevant triplets; one of "keyword",
-            "embedding", or "hybrid". Defaults to SearchMode.KEYWORD.
-        graph_store_query_depth (int): The depth of the graph store query.
-        extract_subject_entities_fn (Optional[Callable]): extract_subject_entities callback.
-    """
-
-    def __init__(
-        self,
-        graph_engine,
-        model_name: str = None,
-        max_keywords_per_query: int = 10,
-        num_chunks_per_query: int = 10,
-        search_mode: Optional[SearchMode] = SearchMode.KEYWORD,
-        graph_store_query_depth: int = 2,
-        extract_subject_entities_fn: Optional[Callable] = None,
-        **kwargs: Any,
-    ) -> None:
-        """Initialize params."""
-        from dbgpt.rag.graph.graph_engine import RAGGraphEngine
-
-        self.graph_engine: RAGGraphEngine = graph_engine
-        self.model_name = model_name or self.graph_engine.model_name
-        self._index_struct = self.graph_engine.index_struct
-        self.max_keywords_per_query = max_keywords_per_query
-        self.num_chunks_per_query = num_chunks_per_query
-        self._search_mode = search_mode
-
-        self._graph_store = self.graph_engine.graph_store
-        self.graph_store_query_depth = graph_store_query_depth
-        self._verbose = kwargs.get("verbose", False)
-        refresh_schema = kwargs.get("refresh_schema", False)
-        self.extract_subject_entities_fn = extract_subject_entities_fn
-        self.executor = ThreadPoolExecutor(max_workers=os.cpu_count() * 5)
-        try:
-            self._graph_schema = self._graph_store.get_schema(refresh=refresh_schema)
-        except NotImplementedError:
-            self._graph_schema = ""
-        except Exception as e:
-            logger.warn(f"cannot find graph schema: {e}")
-            self._graph_schema = ""
-
-    async def _extract_subject_entities(self, query_str: str) -> Set[str]:
-        """Extract subject entities."""
-        if self.extract_subject_entities_fn is not None:
-            return await self.extract_subject_entities_fn(query_str)
-        else:
-            return await self._extract_entities_by_llm(query_str)
-
-    async def _extract_entities_by_llm(self, text: str) -> Set[str]:
-        """Extract subject entities from text by llm."""
-        import uuid
-
-        from dbgpt.app.scene import ChatScene
-        from dbgpt.util.chat_util import llm_chat_response_nostream
-
-        chat_param = {
-            "chat_session_id": uuid.uuid1(),
-            "current_user_input": text,
-            "select_param": "entity",
-            "model_name": self.model_name,
-        }
-        # loop = util.get_or_create_event_loop()
-        # entities = loop.run_until_complete(
-        #     llm_chat_response_nostream(
-        #         ChatScene.ExtractEntity.value(), **{"chat_param": chat_param}
-        #     )
-        # )
-        return await llm_chat_response_nostream(
-            ChatScene.ExtractEntity.value(), **{"chat_param": chat_param}
        )
-
-    async def _search(
-        self,
-        query_str: str,
-    ) -> List[Document]:
-        """Get nodes for response."""
-        node_visited = set()
-        keywords = await self._extract_subject_entities(query_str)
-        print(f"extract entities: {keywords}\n")
-        rel_texts = []
-        cur_rel_map = {}
-        chunk_indices_count: Dict[str, int] = defaultdict(int)
-        if self._search_mode != SearchMode.EMBEDDING:
-            for keyword in keywords:
-                keyword = keyword.lower()
-                subjs = set((keyword,))
-                # node_ids = self._index_struct.search_node_by_keyword(keyword)
-                # for node_id in node_ids[:GLOBAL_EXPLORE_NODE_LIMIT]:
-                #     if node_id in node_visited:
-                #         continue
-                #
-                #     # if self._include_text:
-                #     #     chunk_indices_count[node_id] += 1
-                #
-                #     node_visited.add(node_id)
-
-                rel_map = self._graph_store.get_rel_map(
-                    list(subjs), self.graph_store_query_depth
-                )
-                logger.debug(f"rel_map: {rel_map}")
-
-                if not rel_map:
-                    continue
-                rel_texts.extend(
-                    [
-                        str(rel_obj)
-                        for rel_objs in rel_map.values()
-                        for rel_obj in rel_objs
-                    ]
-                )
-                cur_rel_map.update(rel_map)
-
-        sorted_nodes_with_scores = []
-        if not rel_texts:
-            logger.info("> No relationships found, returning nodes found by keywords.")
-            if len(sorted_nodes_with_scores) == 0:
-                logger.info("> No nodes found by keywords, returning empty response.")
-                return [Document(page_content="No relationships found.")]
-
-        # add relationships as Node
-        # TODO: make initial text customizable
-        rel_initial_text = (
-            f"The following are knowledge sequence in max depth"
-            f" {self.graph_store_query_depth} "
-            f"in the form of directed graph like:\n"
-            f"`subject -[predicate]->, object, <-[predicate_next_hop]-,"
-            f" object_next_hop ...`"
-        )
-        rel_info = [rel_initial_text] + rel_texts
-        rel_node_info = {
-            "kg_rel_texts": rel_texts,
-            "kg_rel_map": cur_rel_map,
-        }
-        if self._graph_schema != "":
-            rel_node_info["kg_schema"] = {"schema": self._graph_schema}
-        rel_info_text = "\n".join(
-            [
-                str(item)
-                for sublist in rel_info
-                for item in (sublist if isinstance(sublist, list) else [sublist])
-            ]
-        )
-        if self._verbose:
-            print(f"KG context:\n{rel_info_text}\n", color="blue")
-        rel_text_node = TextNode(
-            text=rel_info_text,
-            metadata=rel_node_info,
-            excluded_embed_metadata_keys=["kg_rel_map", "kg_rel_texts"],
-            excluded_llm_metadata_keys=["kg_rel_map", "kg_rel_texts"],
-        )
-        # this node is constructed from rel_texts, give high confidence to avoid cutoff
-        sorted_nodes_with_scores.append(
-            NodeWithScore(node=rel_text_node, score=DEFAULT_NODE_SCORE)
-        )
-        docs = [
-            Document(page_content=node.text, metadata=node.metadata)
-            for node in sorted_nodes_with_scores
-        ]
-        return docs
-
-    def _get_metadata_for_response(
-        self, nodes: List[BaseNode]
-    ) -> Optional[Dict[str, Any]]:
-        """Get metadata for response."""
-        for node in nodes:
-            if node.metadata is None or "kg_rel_map" not in node.metadata:
-                continue
-            return node.metadata
-        raise ValueError("kg_rel_map must be found in at least one Node.")
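The retrieval side removed above follows the same pattern: lowercase each extracted entity, expand it through the graph store's relation map up to a fixed depth, and join the flattened relation strings into one context document. Below is a runnable, stdlib-only sketch of that traversal; GRAPH, get_rel_map, and search are hypothetical stand-ins for the graph store API, not the deleted classes.

# graph_search_sketch.py -- illustrative only
from typing import Dict, List, Tuple

GRAPH: Dict[str, List[Tuple[str, str]]] = {
    "db-gpt": [("supports", "AWEL")],
    "awel": [("simplifies", "LLM workflows")],
}


def get_rel_map(subjs: List[str], depth: int) -> Dict[str, List[str]]:
    """Depth-limited expansion, like graph_store.get_rel_map(subjs, depth)."""
    rel_map: Dict[str, List[str]] = {}
    for subj in subjs:
        frontier, rels = [subj], []
        for _ in range(depth):
            next_frontier = []
            for node in frontier:
                for rel, obj in GRAPH.get(node.lower(), []):
                    rels.append(f"{node} -[{rel}]-> {obj}")
                    next_frontier.append(obj)
            frontier = next_frontier
        if rels:
            rel_map[subj] = rels
    return rel_map


def search(keywords: List[str], depth: int = 2) -> str:
    rel_texts: List[str] = []
    for keyword in keywords:
        rel_map = get_rel_map([keyword.lower()], depth)
        for rel_objs in rel_map.values():
            rel_texts.extend(str(rel_obj) for rel_obj in rel_objs)
    if not rel_texts:
        return "No relationships found."
    header = f"The following is a knowledge sequence in max depth {depth}:"
    return "\n".join([header] + rel_texts)


print(search(["DB-GPT"]))
# -> the two-hop chain: DB-GPT -[supports]-> AWEL -[simplifies]-> LLM workflows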
diff --git a/dbgpt/rag/graph/index_struct.py b/dbgpt/rag/graph/index_struct.py
deleted file mode 100644
index 010d1dcea..000000000
--- a/dbgpt/rag/graph/index_struct.py
+++ /dev/null
@@ -1,258 +0,0 @@
-"""Data structures.
-
-Nodes are decoupled from the indices.
-
-"""
-
-import uuid
-from abc import abstractmethod
-from dataclasses import dataclass, field
-from typing import Dict, List, Optional, Sequence, Set
-
-from dataclasses_json import DataClassJsonMixin
-
-from dbgpt.rag.graph.index_type import IndexStructType
-from dbgpt.rag.graph.node import BaseNode, TextNode
-
-# TODO: legacy backport of old Node class
-Node = TextNode
-
-
-@dataclass
-class IndexStruct(DataClassJsonMixin):
-    """A base data struct for a LlamaIndex."""
-
-    index_id: str = field(default_factory=lambda: str(uuid.uuid4()))
-    summary: Optional[str] = None
-
-    def get_summary(self) -> str:
-        """Get text summary."""
-        if self.summary is None:
-            raise ValueError("summary field of the index_struct not set.")
-        return self.summary
-
-    @classmethod
-    @abstractmethod
-    def get_type(cls):
-        """Get index struct type."""
-
-
-@dataclass
-class IndexGraph(IndexStruct):
-    """A graph representing the tree-structured index."""
-
-    # mapping from index in tree to Node doc id.
-    all_nodes: Dict[int, str] = field(default_factory=dict)
-    root_nodes: Dict[int, str] = field(default_factory=dict)
-    node_id_to_children_ids: Dict[str, List[str]] = field(default_factory=dict)
-
-    @property
-    def node_id_to_index(self) -> Dict[str, int]:
-        """Map from node id to index."""
-        return {node_id: index for index, node_id in self.all_nodes.items()}
-
-    @property
-    def size(self) -> int:
-        """Get the size of the graph."""
-        return len(self.all_nodes)
-
-    def get_index(self, node: BaseNode) -> int:
-        """Get index of node."""
-        return self.node_id_to_index[node.node_id]
-
-    def insert(
-        self,
-        node: BaseNode,
-        index: Optional[int] = None,
-        children_nodes: Optional[Sequence[BaseNode]] = None,
-    ) -> None:
-        """Insert node."""
-        index = index or self.size
-        node_id = node.node_id
-
-        self.all_nodes[index] = node_id
-
-        if children_nodes is None:
-            children_nodes = []
-        children_ids = [n.node_id for n in children_nodes]
-        self.node_id_to_children_ids[node_id] = children_ids
-
-    def get_children(self, parent_node: Optional[BaseNode]) -> Dict[int, str]:
-        """Get children nodes."""
-        if parent_node is None:
-            return self.root_nodes
-        else:
-            parent_id = parent_node.node_id
-            children_ids = self.node_id_to_children_ids[parent_id]
-            return {
-                self.node_id_to_index[child_id]: child_id for child_id in children_ids
-            }
-
-    def insert_under_parent(
-        self,
-        node: BaseNode,
-        parent_node: Optional[BaseNode],
-        new_index: Optional[int] = None,
-    ) -> None:
-        """Insert under parent node."""
-        new_index = new_index or self.size
-        if parent_node is None:
-            self.root_nodes[new_index] = node.node_id
-            self.node_id_to_children_ids[node.node_id] = []
-        else:
-            if parent_node.node_id not in self.node_id_to_children_ids:
-                self.node_id_to_children_ids[parent_node.node_id] = []
-            self.node_id_to_children_ids[parent_node.node_id].append(node.node_id)
-
-        self.all_nodes[new_index] = node.node_id
-
-    @classmethod
-    def get_type(cls) -> IndexStructType:
-        """Get type."""
-        return IndexStructType.TREE
-
-
-@dataclass
-class KeywordTable(IndexStruct):
-    """A table of keywords mapping keywords to text chunks."""
-
-    table: Dict[str, Set[str]] = field(default_factory=dict)
-
-    def add_node(self, keywords: List[str], node: BaseNode) -> None:
-        """Add text to table."""
-        for keyword in keywords:
-            if keyword not in self.table:
-                self.table[keyword] = set()
-            self.table[keyword].add(node.node_id)
-
-    @property
-    def node_ids(self) -> Set[str]:
-        """Get all node ids."""
-        return set.union(*self.table.values())
-
-    @property
-    def keywords(self) -> Set[str]:
-        """Get all keywords in the table."""
-        return set(self.table.keys())
-
-    @property
-    def size(self) -> int:
-        """Get the size of the table."""
-        return len(self.table)
-
-    @classmethod
-    def get_type(cls) -> IndexStructType:
-        """Get type."""
-        return IndexStructType.KEYWORD_TABLE
-
-
-@dataclass
-class IndexList(IndexStruct):
-    """A list of documents."""
-
-    nodes: List[str] = field(default_factory=list)
-
-    def add_node(self, node: BaseNode) -> None:
-        """Add text to table, return current position in list."""
-        # don't worry about child indices for now, nodes are all in order
-        self.nodes.append(node.node_id)
-
-    @classmethod
-    def get_type(cls) -> IndexStructType:
-        """Get type."""
-        return IndexStructType.LIST
-
-
-@dataclass
-class IndexDict(IndexStruct):
-    """A simple dictionary of documents."""
-
-    # TODO: slightly deprecated, should likely be a list or set now
-    # mapping from vector store id to node doc_id
-    nodes_dict: Dict[str, str] = field(default_factory=dict)
-
-    # TODO: deprecated, not used
-    # mapping from node doc_id to vector store id
-    doc_id_dict: Dict[str, List[str]] = field(default_factory=dict)
-
-    # TODO: deprecated, not used
-    # this should be empty for all other indices
-    embeddings_dict: Dict[str, List[float]] = field(default_factory=dict)
-
-    def add_node(
-        self,
-        node: BaseNode,
-        text_id: Optional[str] = None,
-    ) -> str:
-        """Add text to table, return current position in list."""
-        # # don't worry about child indices for now, nodes are all in order
-        # self.nodes_dict[int_id] = node
-        vector_id = text_id if text_id is not None else node.node_id
-        self.nodes_dict[vector_id] = node.node_id
-
-        return vector_id
-
-    def delete(self, doc_id: str) -> None:
-        """Delete a Node."""
-        del self.nodes_dict[doc_id]
-
-    @classmethod
-    def get_type(cls) -> IndexStructType:
-        """Get type."""
-        return IndexStructType.VECTOR_STORE
-
-
-@dataclass
-class KG(IndexStruct):
-    """A table of keywords mapping keywords to text chunks."""
-
-    # Unidirectional
-
-    # table of keywords to node ids
-    table: Dict[str, Set[str]] = field(default_factory=dict)
-
-    # TODO: legacy attribute, remove in future releases
-    rel_map: Dict[str, List[List[str]]] = field(default_factory=dict)
-
-    # TBD, should support vector store, now we just persist the embedding memory
-    # maybe chainable abstractions for *_stores could be designed
-    embedding_dict: Dict[str, List[float]] = field(default_factory=dict)
-
-    @property
-    def node_ids(self) -> Set[str]:
-        """Get all node ids."""
-        return set.union(*self.table.values())
-
-    def add_to_embedding_dict(self, triplet_str: str, embedding: List[float]) -> None:
-        """Add embedding to dict."""
-        self.embedding_dict[triplet_str] = embedding
-
-    def add_node(self, keywords: List[str], node: BaseNode) -> None:
-        """Add text to table."""
-        node_id = node.node_id
-        for keyword in keywords:
-            keyword = keyword.lower()
-            if keyword not in self.table:
-                self.table[keyword] = set()
-            self.table[keyword].add(node_id)
-
-    def search_node_by_keyword(self, keyword: str) -> List[str]:
-        """Search for nodes by keyword."""
-        if keyword not in self.table:
-            return []
-        return list(self.table[keyword])
-
-    @classmethod
-    def get_type(cls) -> IndexStructType:
-        """Get type."""
-        return IndexStructType.KG
-
-
-@dataclass
-class EmptyIndexStruct(IndexStruct):
-    """Empty index."""
-
-    @classmethod
-    def get_type(cls) -> IndexStructType:
-        """Get type."""
-        return IndexStructType.EMPTY
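The KG struct deleted above is, at its core, an inverted index from lowercased keywords to node ids plus an optional triplet-embedding dict. A stripped-down equivalent (hypothetical, stdlib-only; note the removed KG lowercased only on add, while this stand-in normalizes on both sides) behaves like this:

# keyword_table_sketch.py -- illustrative only
from typing import Dict, List, Set


class KeywordTable:
    def __init__(self) -> None:
        self.table: Dict[str, Set[str]] = {}

    def add_node(self, keywords: List[str], node_id: str) -> None:
        for keyword in keywords:
            self.table.setdefault(keyword.lower(), set()).add(node_id)

    def search_node_by_keyword(self, keyword: str) -> List[str]:
        return list(self.table.get(keyword.lower(), set()))


kg = KeywordTable()
kg.add_node(["DB-GPT", "AWEL"], node_id="chunk-0")
assert kg.search_node_by_keyword("awel") == ["chunk-0"]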
diff --git a/dbgpt/rag/graph/index_type.py b/dbgpt/rag/graph/index_type.py
deleted file mode 100644
index 939066be9..000000000
--- a/dbgpt/rag/graph/index_type.py
+++ /dev/null
@@ -1,48 +0,0 @@
-"""IndexStructType class."""
-
-from enum import Enum
-
-
-class IndexStructType(str, Enum):
-    """Index struct type. Identifier for a "type" of index.
-
-    Attributes:
-        TREE ("tree"): Tree index. See :ref:`Ref-Indices-Tree` for tree indices.
-        LIST ("list"): Summary index. See :ref:`Ref-Indices-List` for summary indices.
-        KEYWORD_TABLE ("keyword_table"): Keyword table index. See
-            :ref:`Ref-Indices-Table` for keyword table indices.
-        DICT ("dict"): Faiss Vector Store Index. See
-            :ref:`Ref-Indices-VectorStore`
-            for more information on the faiss vector store index.
-        SIMPLE_DICT ("simple_dict"): Simple Vector Store Index. See
-            :ref:`Ref-Indices-VectorStore`
-            for more information on the simple vector store index.
-        KG ("kg"): Knowledge Graph index.
-            See :ref:`Ref-Indices-Knowledge-Graph` for KG indices.
-        DOCUMENT_SUMMARY ("document_summary"): Document Summary Index.
-            See :ref:`Ref-Indices-Document-Summary` for Summary Indices.
-
-    """
-
-    # TODO: refactor so these are properties on the base class
-
-    NODE = "node"
-    TREE = "tree"
-    LIST = "list"
-    KEYWORD_TABLE = "keyword_table"
-
-    DICT = "dict"
-    # simple
-    SIMPLE_DICT = "simple_dict"
-    # for KG index
-    KG = "kg"
-    SIMPLE_KG = "simple_kg"
-    NEBULAGRAPH = "nebulagraph"
-    FALKORDB = "falkordb"
-
-    # EMPTY
-    EMPTY = "empty"
-    COMPOSITE = "composite"
-
-    DOCUMENT_SUMMARY = "document_summary"
diff --git a/dbgpt/rag/graph/kv_index.py b/dbgpt/rag/graph/kv_index.py
deleted file mode 100644
index 963b1da0e..000000000
--- a/dbgpt/rag/graph/kv_index.py
+++ /dev/null
@@ -1,75 +0,0 @@
-from typing import List, Optional
-
-from llama_index.data_structs.data_structs import IndexStruct
-from llama_index.storage.index_store.utils import (
-    index_struct_to_json,
-    json_to_index_struct,
-)
-from llama_index.storage.kvstore.types import BaseKVStore
-
-DEFAULT_NAMESPACE = "index_store"
-
-
-class KVIndexStore:
-    """Key-Value Index store.
-
-    Args:
-        kvstore (BaseKVStore): key-value store
-        namespace (str): namespace for the index store
-
-    """
-
-    def __init__(self, kvstore: BaseKVStore, namespace: Optional[str] = None) -> None:
-        """Init a KVIndexStore."""
-        self._kvstore = kvstore
-        self._namespace = namespace or DEFAULT_NAMESPACE
-        self._collection = f"{self._namespace}/data"
-
-    def add_index_struct(self, index_struct: IndexStruct) -> None:
-        """Add an index struct.
-
-        Args:
-            index_struct (IndexStruct): index struct
-
-        """
-        key = index_struct.index_id
-        data = index_struct_to_json(index_struct)
-        self._kvstore.put(key, data, collection=self._collection)
-
-    def delete_index_struct(self, key: str) -> None:
-        """Delete an index struct.
-
-        Args:
-            key (str): index struct key
-
-        """
-        self._kvstore.delete(key, collection=self._collection)
-
-    def get_index_struct(
-        self, struct_id: Optional[str] = None
-    ) -> Optional[IndexStruct]:
-        """Get an index struct.
-
-        Args:
-            struct_id (Optional[str]): index struct id
-
-        """
-        if struct_id is None:
-            structs = self.index_structs()
-            assert len(structs) == 1
-            return structs[0]
-        else:
-            json = self._kvstore.get(struct_id, collection=self._collection)
-            if json is None:
-                return None
-            return json_to_index_struct(json)
-
-    def index_structs(self) -> List[IndexStruct]:
-        """Get all index structs.
-
-        Returns:
-            List[IndexStruct]: index structs
-
-        """
-        jsons = self._kvstore.get_all(collection=self._collection)
-        return [json_to_index_struct(json) for json in jsons.values()]
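Rough usage sketch for the removed KVIndexStore: index structs are serialized to JSON and kept under the "<namespace>/data" collection of any BaseKVStore. DictKVStore below is a hypothetical in-memory stand-in, not a llama-index class; only the put/get/get_all/delete shape shown in the diff above is assumed.

# kv_index_sketch.py -- illustrative only
from typing import Dict, Optional


class DictKVStore:
    def __init__(self) -> None:
        self._data: Dict[str, Dict[str, dict]] = {}

    def put(self, key: str, val: dict, collection: str) -> None:
        self._data.setdefault(collection, {})[key] = val

    def get(self, key: str, collection: str) -> Optional[dict]:
        return self._data.get(collection, {}).get(key)

    def get_all(self, collection: str) -> Dict[str, dict]:
        return dict(self._data.get(collection, {}))

    def delete(self, key: str, collection: str) -> None:
        self._data.get(collection, {}).pop(key, None)


store = DictKVStore()
collection = "index_store/data"  # f"{namespace}/data", as in KVIndexStore
store.put("kg-index-1", {"type": "kg", "summary": None}, collection=collection)
print(store.get_all(collection=collection))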
diff --git a/dbgpt/rag/graph/node.py b/dbgpt/rag/graph/node.py
deleted file mode 100644
index aef3e4c30..000000000
--- a/dbgpt/rag/graph/node.py
+++ /dev/null
@@ -1,570 +0,0 @@
-"""Base schema for data structures."""
-import json
-import textwrap
-import uuid
-from abc import abstractmethod
-from enum import Enum, auto
-from hashlib import sha256
-from typing import Any, Dict, List, Optional, Union
-
-from langchain.schema import Document
-from typing_extensions import Self
-
-from dbgpt._private.pydantic import BaseModel, Field, root_validator
-
-DEFAULT_TEXT_NODE_TMPL = "{metadata_str}\n\n{content}"
-DEFAULT_METADATA_TMPL = "{key}: {value}"
-# NOTE: for pretty printing
-TRUNCATE_LENGTH = 350
-WRAP_WIDTH = 70
-
-
-class BaseComponent(BaseModel):
-    """Base component object to capture class names.
-
-    Reference: llama-index.
-    """
-
-    @classmethod
-    @abstractmethod
-    def class_name(cls) -> str:
-        """Get class name."""
-
-    def to_dict(self, **kwargs: Any) -> Dict[str, Any]:
-        data = self.dict(**kwargs)
-        data["class_name"] = self.class_name()
-        return data
-
-    def to_json(self, **kwargs: Any) -> str:
-        data = self.to_dict(**kwargs)
-        return json.dumps(data)
-
-    # TODO: return type here not supported by current mypy version
-    @classmethod
-    def from_dict(cls, data: Dict[str, Any], **kwargs: Any) -> Self:  # type: ignore
-        if isinstance(kwargs, dict):
-            data.update(kwargs)
-
-        data.pop("class_name", None)
-        return cls(**data)
-
-    @classmethod
-    def from_json(cls, data_str: str, **kwargs: Any) -> Self:  # type: ignore
-        data = json.loads(data_str)
-        return cls.from_dict(data, **kwargs)
-
-
-class NodeRelationship(str, Enum):
-    """Node relationships used in `BaseNode` class.
-
-    Attributes:
-        SOURCE: The node is the source document.
-        PREVIOUS: The node is the previous node in the document.
-        NEXT: The node is the next node in the document.
-        PARENT: The node is the parent node in the document.
-        CHILD: The node is a child node in the document.
-
-    """
-
-    SOURCE = auto()
-    PREVIOUS = auto()
-    NEXT = auto()
-    PARENT = auto()
-    CHILD = auto()
-
-
-class ObjectType(str, Enum):
-    TEXT = auto()
-    IMAGE = auto()
-    INDEX = auto()
-    DOCUMENT = auto()
-
-
-class MetadataMode(str, Enum):
-    ALL = auto()
-    EMBED = auto()
-    LLM = auto()
-    NONE = auto()
-
-
-class RelatedNodeInfo(BaseComponent):
-    node_id: str
-    node_type: Optional[ObjectType] = None
-    metadata: Dict[str, Any] = Field(default_factory=dict)
-    hash: Optional[str] = None
-
-    @classmethod
-    def class_name(cls) -> str:
-        """Get class name."""
-        return "RelatedNodeInfo"
-
-
-RelatedNodeType = Union[RelatedNodeInfo, List[RelatedNodeInfo]]
-
-
-# Node classes for indexes
-class BaseNode(BaseComponent):
-    """Base node Object.
-
-    Generic abstract interface for retrievable nodes.
-
-    """
-
-    class Config:
-        allow_population_by_field_name = True
-
-    id_: str = Field(
-        default_factory=lambda: str(uuid.uuid4()), description="Unique ID of the node."
-    )
-    embedding: Optional[List[float]] = Field(
-        default=None, description="Embedding of the node."
-    )
-
-    """"
-    metadata fields
-    - injected as part of the text shown to LLMs as context
-    - injected as part of the text for generating embeddings
-    - used by vector DBs for metadata filtering
-
-    """
-    metadata: Dict[str, Any] = Field(
-        default_factory=dict,
-        description="A flat dictionary of metadata fields",
-        alias="extra_info",
-    )
-    excluded_embed_metadata_keys: List[str] = Field(
-        default_factory=list,
-        description="Metadata keys that are excluded from text for the embed model.",
-    )
-    excluded_llm_metadata_keys: List[str] = Field(
-        default_factory=list,
-        description="Metadata keys that are excluded from text for the LLM.",
-    )
-    relationships: Dict[NodeRelationship, RelatedNodeType] = Field(
-        default_factory=dict,
-        description="A mapping of relationships to other node information.",
-    )
-    hash: str = Field(default="", description="Hash of the node content.")
-
-    @classmethod
-    @abstractmethod
-    def get_type(cls) -> str:
-        """Get Object type."""
-
-    @abstractmethod
-    def get_content(self, metadata_mode: MetadataMode = MetadataMode.ALL) -> str:
-        """Get object content."""
-
-    @abstractmethod
-    def get_metadata_str(self, mode: MetadataMode = MetadataMode.ALL) -> str:
-        """Metadata string."""
-
-    @abstractmethod
-    def set_content(self, value: Any) -> None:
-        """Set the content of the node."""
-
-    @property
-    def node_id(self) -> str:
-        return self.id_
-
-    @node_id.setter
-    def node_id(self, value: str) -> None:
-        self.id_ = value
-
-    @property
-    def source_node(self) -> Optional[RelatedNodeInfo]:
-        """Source object node.
-
-        Extracted from the relationships field.
-
-        """
-        if NodeRelationship.SOURCE not in self.relationships:
-            return None
-
-        relation = self.relationships[NodeRelationship.SOURCE]
-        if isinstance(relation, list):
-            raise ValueError("Source object must be a single RelatedNodeInfo object")
-        return relation
-
-    @property
-    def prev_node(self) -> Optional[RelatedNodeInfo]:
-        """Prev node."""
-        if NodeRelationship.PREVIOUS not in self.relationships:
-            return None
-
-        relation = self.relationships[NodeRelationship.PREVIOUS]
-        if not isinstance(relation, RelatedNodeInfo):
-            raise ValueError("Previous object must be a single RelatedNodeInfo object")
-        return relation
-
-    @property
-    def next_node(self) -> Optional[RelatedNodeInfo]:
-        """Next node."""
-        if NodeRelationship.NEXT not in self.relationships:
-            return None
-
-        relation = self.relationships[NodeRelationship.NEXT]
-        if not isinstance(relation, RelatedNodeInfo):
-            raise ValueError("Next object must be a single RelatedNodeInfo object")
-        return relation
-
-    @property
-    def parent_node(self) -> Optional[RelatedNodeInfo]:
-        """Parent node."""
-        if NodeRelationship.PARENT not in self.relationships:
-            return None
-
-        relation = self.relationships[NodeRelationship.PARENT]
-        if not isinstance(relation, RelatedNodeInfo):
-            raise ValueError("Parent object must be a single RelatedNodeInfo object")
-        return relation
-
-    @property
-    def child_nodes(self) -> Optional[List[RelatedNodeInfo]]:
-        """Child nodes."""
-        if NodeRelationship.CHILD not in self.relationships:
-            return None
-
-        relation = self.relationships[NodeRelationship.CHILD]
-        if not isinstance(relation, list):
-            raise ValueError("Child objects must be a list of RelatedNodeInfo objects.")
-        return relation
-
-    @property
-    def ref_doc_id(self) -> Optional[str]:
-        """Deprecated: Get ref doc id."""
-        source_node = self.source_node
-        if source_node is None:
-            return None
-        return source_node.node_id
-
-    @property
-    def extra_info(self) -> Dict[str, Any]:
-        """TODO: DEPRECATED: Extra info."""
-        return self.metadata
-
-    def __str__(self) -> str:
-        source_text_truncated = self.truncate_text(
-            self.get_content().strip(), TRUNCATE_LENGTH
-        )
-        source_text_wrapped = textwrap.fill(
-            f"Text: {source_text_truncated}\n", width=WRAP_WIDTH
-        )
-        return f"Node ID: {self.node_id}\n{source_text_wrapped}"
-
-    @staticmethod
-    def truncate_text(text: str, max_length: int) -> str:
-        """Truncate text to a maximum length."""
-        if len(text) <= max_length:
-            return text
-        return text[: max_length - 3] + "..."
-
-    def get_embedding(self) -> List[float]:
-        """Get embedding.
-
-        Errors if embedding is None.
-
-        """
-        if self.embedding is None:
-            raise ValueError("embedding not set.")
-        return self.embedding
-
-    def as_related_node_info(self) -> RelatedNodeInfo:
-        """Get node as RelatedNodeInfo."""
-        return RelatedNodeInfo(
-            node_id=self.node_id, metadata=self.metadata, hash=self.hash
-        )
-
-
-class TextNode(BaseNode):
-    text: str = Field(default="", description="Text content of the node.")
-    start_char_idx: Optional[int] = Field(
-        default=None, description="Start char index of the node."
-    )
-    end_char_idx: Optional[int] = Field(
-        default=None, description="End char index of the node."
-    )
-    text_template: str = Field(
-        default=DEFAULT_TEXT_NODE_TMPL,
-        description=(
-            "Template for how text is formatted, with {content} and "
-            "{metadata_str} placeholders."
-        ),
-    )
-    metadata_template: str = Field(
-        default=DEFAULT_METADATA_TMPL,
-        description=(
-            "Template for how metadata is formatted, with {key} and "
-            "{value} placeholders."
-        ),
-    )
-    metadata_seperator: str = Field(
-        default="\n",
-        description="Separator between metadata fields when converting to string.",
-    )
-
-    @classmethod
-    def class_name(cls) -> str:
-        """Get class name."""
-        return "TextNode"
-
-    @root_validator
-    def _check_hash(cls, values: dict) -> dict:
-        """Generate a hash to represent the node."""
-        text = values.get("text", "")
-        metadata = values.get("metadata", {})
-        doc_identity = str(text) + str(metadata)
-        values["hash"] = str(
-            sha256(doc_identity.encode("utf-8", "surrogatepass")).hexdigest()
-        )
-        return values
-
-    @classmethod
-    def get_type(cls) -> str:
-        """Get Object type."""
-        return ObjectType.TEXT
-
-    def get_content(self, metadata_mode: MetadataMode = MetadataMode.NONE) -> str:
-        """Get object content."""
-        metadata_str = self.get_metadata_str(mode=metadata_mode).strip()
-        if not metadata_str:
-            return self.text
-
-        return self.text_template.format(
-            content=self.text, metadata_str=metadata_str
-        ).strip()
-
-    def get_metadata_str(self, mode: MetadataMode = MetadataMode.ALL) -> str:
-        """Metadata info string."""
-        if mode == MetadataMode.NONE:
-            return ""
-
-        usable_metadata_keys = set(self.metadata.keys())
-        if mode == MetadataMode.LLM:
-            for key in self.excluded_llm_metadata_keys:
-                if key in usable_metadata_keys:
-                    usable_metadata_keys.remove(key)
-        elif mode == MetadataMode.EMBED:
-            for key in self.excluded_embed_metadata_keys:
-                if key in usable_metadata_keys:
-                    usable_metadata_keys.remove(key)
-
-        return self.metadata_seperator.join(
-            [
-                self.metadata_template.format(key=key, value=str(value))
-                for key, value in self.metadata.items()
-                if key in usable_metadata_keys
-            ]
-        )
-
-    def set_content(self, value: str) -> None:
-        """Set the content of the node."""
-        self.text = value
-
-    def get_node_info(self) -> Dict[str, Any]:
-        """Get node info."""
-        return {"start": self.start_char_idx, "end": self.end_char_idx}
-
-    def get_text(self) -> str:
-        return self.get_content(metadata_mode=MetadataMode.NONE)
-
-    @property
-    def node_info(self) -> Dict[str, Any]:
-        """Deprecated: Get node info."""
-        return self.get_node_info()
-
-
-# TODO: legacy backport of old Node class
-Node = TextNode
-
-
-class ImageNode(TextNode):
-    """Node with image."""
-
-    # TODO: store reference instead of actual image
-    # base64 encoded image str
-    image: Optional[str] = None
-
-    @classmethod
-    def get_type(cls) -> str:
-        return ObjectType.IMAGE
-
-    @classmethod
-    def class_name(cls) -> str:
-        """Get class name."""
-        return "ImageNode"
-
-
-class IndexNode(TextNode):
-    """Node with reference to any object.
-
-    This can include other indices, query engines, retrievers.
-
-    This can also include other nodes (though this is overlapping with `relationships`
-    on the Node class).
-
-    """
-
-    index_id: str
-
-    @classmethod
-    def from_text_node(
-        cls,
-        node: TextNode,
-        index_id: str,
-    ) -> "IndexNode":
-        """Create index node from text node."""
-        # copy all attributes from text node, add index id
-        return cls(
-            **node.dict(),
-            index_id=index_id,
-        )
-
-    @classmethod
-    def get_type(cls) -> str:
-        return ObjectType.INDEX
-
-    @classmethod
-    def class_name(cls) -> str:
-        """Get class name."""
-        return "IndexNode"
-
-
-class NodeWithScore(BaseComponent):
-    node: BaseNode
-    score: Optional[float] = None
-
-    def __str__(self) -> str:
-        return f"{self.node}\nScore: {self.score: 0.3f}\n"
-
-    def get_score(self, raise_error: bool = False) -> float:
-        """Get score."""
-        if self.score is None:
-            if raise_error:
-                raise ValueError("Score not set.")
-            else:
-                return 0.0
-        else:
-            return self.score
-
-    @classmethod
-    def class_name(cls) -> str:
-        """Get class name."""
-        return "NodeWithScore"
-
-    ##### pass through methods to BaseNode #####
-    @property
-    def node_id(self) -> str:
-        return self.node.node_id
-
-    @property
-    def id_(self) -> str:
-        return self.node.id_
-
-    @property
-    def text(self) -> str:
-        if isinstance(self.node, TextNode):
-            return self.node.text
-        else:
-            raise ValueError("Node must be a TextNode to get text.")
-
-    @property
-    def metadata(self) -> Dict[str, Any]:
-        return self.node.metadata
-
-    @property
-    def embedding(self) -> Optional[List[float]]:
-        return self.node.embedding
-
-    def get_text(self) -> str:
-        if isinstance(self.node, TextNode):
-            return self.node.get_text()
-        else:
-            raise ValueError("Node must be a TextNode to get text.")
-
-    def get_content(self, metadata_mode: MetadataMode = MetadataMode.NONE) -> str:
-        return self.node.get_content(metadata_mode=metadata_mode)
-
-    def get_embedding(self) -> List[float]:
-        return self.node.get_embedding()
-
-
-# Document Classes for Readers
-
-
-class Document(TextNode):
-    """Generic interface for a data document.
-
-    This document connects to data sources.
-
-    """
-
-    # TODO: A lot of backwards compatibility logic here, clean up
-    id_: str = Field(
-        default_factory=lambda: str(uuid.uuid4()),
-        description="Unique ID of the node.",
-        alias="doc_id",
-    )
-
-    _compat_fields = {"doc_id": "id_", "extra_info": "metadata"}
-
-    @classmethod
-    def get_type(cls) -> str:
-        """Get Document type."""
-        return ObjectType.DOCUMENT
-
-    @property
-    def doc_id(self) -> str:
-        """Get document ID."""
-        return self.id_
-
-    def __str__(self) -> str:
-        source_text_truncated = self.truncate_text(
-            self.get_content().strip(), TRUNCATE_LENGTH
-        )
-        source_text_wrapped = textwrap.fill(
-            f"Text: {source_text_truncated}\n", width=WRAP_WIDTH
-        )
-        return f"Doc ID: {self.doc_id}\n{source_text_wrapped}"
-
-    def get_doc_id(self) -> str:
-        """TODO: Deprecated: Get document ID."""
-        return self.id_
-
-    def __setattr__(self, name: str, value: object) -> None:
-        if name in self._compat_fields:
-            name = self._compat_fields[name]
-        super().__setattr__(name, value)
-
-    def to_langchain_format(self) -> Document:
-        """Convert struct to LangChain document format."""
-        metadata = self.metadata or {}
-        return Document(page_content=self.text, metadata=metadata)
-
-    @classmethod
-    def from_langchain_format(cls, doc: Document) -> "Document":
-        """Convert struct from LangChain document format."""
-        return cls(text=doc.page_content, metadata=doc.metadata)
-
-    @classmethod
-    def example(cls) -> "Document":
-        document = Document(
-            text="",
-            metadata={"filename": "README.md", "category": "codebase"},
-        )
-        return document
-
-    @classmethod
-    def class_name(cls) -> str:
-        """Get class name."""
-        return "Document"
-
-
-class ImageDocument(Document):
-    """Data document containing an image."""
-
-    # base64 encoded image str
-    image: Optional[str] = None
-
-    @classmethod
-    def class_name(cls) -> str:
-        """Get class name."""
-        return "ImageDocument"
diff --git a/dbgpt/rag/graph/search.py b/dbgpt/rag/graph/search.py
deleted file mode 100644
index 297620b00..000000000
--- a/dbgpt/rag/graph/search.py
+++ /dev/null
@@ -1,44 +0,0 @@
-from abc import ABC, abstractmethod
-from enum import Enum
-
-
-class SearchMode(str, Enum):
-    """Query mode enum for Knowledge Graphs.
-
-    Can be passed as the enum struct, or as the underlying string.
-
-    Attributes:
-        KEYWORD ("keyword"): Default query mode, using keywords to find triplets.
-        EMBEDDING ("embedding"): Embedding mode, using embeddings to find
-            similar triplets.
-        HYBRID ("hybrid"): Hybrid mode, combining both keywords and embeddings
-            to find relevant triplets.
-    """
-
-    KEYWORD = "keyword"
-    EMBEDDING = "embedding"
-    HYBRID = "hybrid"
-
-
-class BaseSearch(ABC):
-    """Base Search."""
-
-    async def search(self, query: str):
-        """Retrieve nodes given query.
-
-        Args:
-            query (QueryType): Either a query string or
-                a QueryBundle object.
-
-        """
-        # if isinstance(query, str):
-        return await self._search(query)
-
-    @abstractmethod
-    async def _search(self, query: str):
-        """Search nodes given query.
-
-        Implemented by the user.
-
-        """
-        pass
- - """ - pass diff --git a/examples/awel/simple_rag_embedding_example.py b/examples/awel/simple_rag_embedding_example.py index 268d4bcb2..a2a6f961b 100644 --- a/examples/awel/simple_rag_embedding_example.py +++ b/examples/awel/simple_rag_embedding_example.py @@ -1,40 +1,32 @@ -import asyncio import os from typing import Dict, List +from pydantic import BaseModel, Field + from dbgpt.configs.model_config import MODEL_PATH, PILOT_PATH -from dbgpt.core.awel import DAG, InputOperator, MapOperator, SimpleCallDataInputSource -from dbgpt.rag.chunk import Chunk +from dbgpt.core.awel import DAG, HttpTrigger, MapOperator from dbgpt.rag.embedding.embedding_factory import DefaultEmbeddingFactory +from dbgpt.rag.knowledge.base import KnowledgeType from dbgpt.rag.operator.knowledge import KnowledgeOperator from dbgpt.serve.rag.operators.embedding import EmbeddingAssemblerOperator from dbgpt.storage.vector_store.chroma_store import ChromaVectorConfig from dbgpt.storage.vector_store.connector import VectorStoreConnector """AWEL: Simple rag embedding operator example - - pre-requirements: - set your file path in your example code. + Examples: + pre-requirements: + python examples/awel/simple_rag_embedding_example.py ..code-block:: shell - python examples/awel/simple_rag_embedding_example.py + curl --location --request POST 'http://127.0.0.1:5555/api/v1/awel/trigger/examples/rag/embedding' \ + --header 'Content-Type: application/json' \ + --data-raw '{ + "url": "https://docs.dbgpt.site/docs/awel" + }' """ -def _context_join_fn(context_dict: Dict, chunks: List[Chunk]) -> Dict: - """context Join function for JoinOperator. - - Args: - context_dict (Dict): context dict - chunks (List[Chunk]): chunks - Returns: - Dict: context dict - """ - context_dict["context"] = "\n".join([chunk.content for chunk in chunks]) - return context_dict - - -def _create_vector_connector(): +def _create_vector_connector() -> VectorStoreConnector: """Create vector connector.""" return VectorStoreConnector.from_default( "Chroma", @@ -48,6 +40,22 @@ def _create_vector_connector(): ) +class TriggerReqBody(BaseModel): + url: str = Field(..., description="url") + + +class RequestHandleOperator(MapOperator[TriggerReqBody, Dict]): + def __init__(self, **kwargs): + super().__init__(**kwargs) + + async def map(self, input_value: TriggerReqBody) -> Dict: + params = { + "url": input_value.url, + } + print(f"Receive input value: {input_value}") + return params + + class ResultOperator(MapOperator): """The Result Operator.""" @@ -61,26 +69,31 @@ async def map(self, chunks: List) -> str: with DAG("simple_sdk_rag_embedding_example") as dag: - knowledge_operator = KnowledgeOperator() + trigger = HttpTrigger( + "/examples/rag/embedding", methods="POST", request_body=TriggerReqBody + ) + request_handle_task = RequestHandleOperator() + knowledge_operator = KnowledgeOperator(knowledge_type=KnowledgeType.URL) vector_connector = _create_vector_connector() - input_task = InputOperator(input_source=SimpleCallDataInputSource()) - file_path_parser = MapOperator(map_function=lambda x: x["file_path"]) + url_parser_operator = MapOperator(map_function=lambda x: x["url"]) embedding_operator = EmbeddingAssemblerOperator( vector_store_connector=vector_connector, ) output_task = ResultOperator() ( - input_task - >> file_path_parser + trigger + >> request_handle_task + >> url_parser_operator >> knowledge_operator >> embedding_operator >> output_task ) if __name__ == "__main__": - input_data = { - "data": { - "file_path": "docs/docs/awel.md", - } - } - output = 
asyncio.run(output_task.call(call_data=input_data)) + if dag.leaf_nodes[0].dev_mode: + # Development mode, you can run the dag locally for debugging. + from dbgpt.core.awel import setup_dev_environment + + setup_dev_environment([dag], port=5555) + else: + pass
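For reference, a small Python client equivalent to the curl command in the module docstring above. This is a sketch assuming the example is running via setup_dev_environment on port 5555, as in the __main__ block.

# trigger_client_sketch.py -- illustrative only
import json
from urllib import request

body = json.dumps({"url": "https://docs.dbgpt.site/docs/awel"}).encode("utf-8")
req = request.Request(
    "http://127.0.0.1:5555/api/v1/awel/trigger/examples/rag/embedding",
    data=body,
    headers={"Content-Type": "application/json"},
    method="POST",
)
with request.urlopen(req) as resp:
    # prints the ResultOperator output, e.g. "embedded N chunk datasources"
    print(resp.read().decode("utf-8"))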