From 7446817340a7041284ec181748f0bc4bbf9e39c3 Mon Sep 17 00:00:00 2001 From: Fangyin Cheng Date: Thu, 7 Mar 2024 23:27:43 +0800 Subject: [PATCH] chore: Add pylint for DB-GPT rag lib (#1267) --- .mypy.ini | 40 +++++ Makefile | 3 + dbgpt/_private/config.py | 25 ++- dbgpt/component.py | 2 +- dbgpt/core/awel/operators/base.py | 2 +- dbgpt/datasource/rdbms/base.py | 4 +- dbgpt/datasource/rdbms/conn_clickhouse.py | 4 +- dbgpt/datasource/rdbms/conn_doris.py | 4 +- dbgpt/datasource/rdbms/conn_postgresql.py | 4 +- dbgpt/datasource/rdbms/conn_sqlite.py | 4 +- dbgpt/datasource/rdbms/conn_starrocks.py | 4 +- dbgpt/rag/__init__.py | 1 + dbgpt/rag/chunk.py | 51 +++--- dbgpt/rag/chunk_manager.py | 33 ++-- dbgpt/rag/embedding/__init__.py | 16 +- dbgpt/rag/embedding/embedding_factory.py | 35 ++++- dbgpt/rag/embedding/embeddings.py | 62 +++++--- dbgpt/rag/extractor/__init__.py | 5 + dbgpt/rag/extractor/base.py | 11 +- dbgpt/rag/extractor/summary.py | 77 +++++---- dbgpt/rag/graph/__init__.py | 1 + dbgpt/rag/knowledge/__init__.py | 29 ++++ dbgpt/rag/knowledge/base.py | 80 ++++++---- dbgpt/rag/knowledge/csv.py | 27 ++-- dbgpt/rag/knowledge/docx.py | 20 ++- dbgpt/rag/knowledge/factory.py | 73 +++++---- dbgpt/rag/knowledge/html.py | 26 +-- dbgpt/rag/knowledge/json.py | 1 + dbgpt/rag/knowledge/markdown.py | 22 ++- dbgpt/rag/knowledge/pdf.py | 32 ++-- dbgpt/rag/knowledge/pptx.py | 30 +++- dbgpt/rag/knowledge/string.py | 22 ++- dbgpt/rag/knowledge/txt.py | 20 ++- dbgpt/rag/knowledge/url.py | 23 ++- dbgpt/rag/operators/__init__.py | 19 +++ dbgpt/rag/operators/datasource.py | 6 + dbgpt/rag/operators/db_schema.py | 15 +- dbgpt/rag/operators/embedding.py | 17 +- dbgpt/rag/operators/knowledge.py | 10 +- dbgpt/rag/operators/rerank.py | 22 +-- dbgpt/rag/operators/rewrite.py | 19 ++- dbgpt/rag/operators/schema_linking.py | 20 ++- dbgpt/rag/operators/summary.py | 18 ++- dbgpt/rag/retriever/__init__.py | 18 +++ dbgpt/rag/retriever/base.py | 30 +++- dbgpt/rag/retriever/db_schema.py | 52 ++++-- dbgpt/rag/retriever/embedding.py | 47 ++++-- dbgpt/rag/retriever/rerank.py | 57 ++++--- dbgpt/rag/retriever/rewrite.py | 36 +++-- dbgpt/rag/schemalinker/__init__.py | 1 + dbgpt/rag/schemalinker/base_linker.py | 20 ++- dbgpt/rag/schemalinker/schema_linking.py | 52 +++--- dbgpt/rag/summary/__init__.py | 18 +++ dbgpt/rag/summary/db_summary.py | 43 +++-- dbgpt/rag/summary/db_summary_client.py | 32 ++-- dbgpt/rag/summary/rdbms_db_summary.py | 28 +++- dbgpt/rag/text_splitter/__init__.py | 23 +++ dbgpt/rag/text_splitter/pre_text_splitter.py | 14 +- dbgpt/rag/text_splitter/text_splitter.py | 148 +++++++++++------- dbgpt/rag/text_splitter/token_splitter.py | 8 +- dbgpt/util/speech/say.py | 2 +- examples/rag/db_schema_rag_example.py | 2 +- examples/rag/embedding_rag_example.py | 4 +- examples/rag/rag_embedding_api_example.py | 2 +- examples/rag/rewrite_rag_example.py | 14 +- .../rag/simple_dbschema_retriever_example.py | 31 ++-- examples/rag/simple_rag_embedding_example.py | 32 ++-- examples/rag/simple_rag_retriever_example.py | 43 ++--- examples/rag/summary_extractor_example.py | 17 +- requirements/lint-requirements.txt | 3 + 70 files changed, 1132 insertions(+), 584 deletions(-) diff --git a/.mypy.ini b/.mypy.ini index 70bae183a..55abad8f7 100644 --- a/.mypy.ini +++ b/.mypy.ini @@ -2,6 +2,21 @@ exclude = /tests/ # plugins = pydantic.mypy +[mypy-dbgpt.app.*] +follow_imports = skip + +[mypy-dbgpt.datasource.*] +follow_imports = skip + +[mypy-dbgpt.storage.*] +follow_imports = skip + +[mypy-dbgpt.serve.*] +follow_imports = skip + +[mypy-dbgpt.util.*] +follow_imports = skip + [mypy-graphviz.*] ignore_missing_imports = True @@ -17,4 +32,29 @@ ignore_missing_imports = True [mypy-pydantic.*] strict_optional = False ignore_missing_imports = True +follow_imports = skip + +[mypy-sentence_transformers.*] +ignore_missing_imports = True + +[mypy-InstructorEmbedding.*] +ignore_missing_imports = True + +[mypy-llama_index.*] +ignore_missing_imports = True + +[mypy-pptx.*] +ignore_missing_imports = True + +[mypy-docx.*] +ignore_missing_imports = True + +[mypy-markdown.*] +ignore_missing_imports = True + +[mypy-auto_gpt_plugin_template.*] +ignore_missing_imports = True + +[mypy-spacy.*] +ignore_missing_imports = True follow_imports = skip \ No newline at end of file diff --git a/Makefile b/Makefile index e8f092feb..2780c15d0 100644 --- a/Makefile +++ b/Makefile @@ -49,6 +49,7 @@ fmt: setup ## Format Python code # TODO: Use flake8 to enforce Python style guide. # https://flake8.pycqa.org/en/latest/ $(VENV_BIN)/flake8 dbgpt/core/ + $(VENV_BIN)/flake8 dbgpt/rag/ # TODO: More package checks with flake8. .PHONY: fmt-check @@ -58,6 +59,7 @@ fmt-check: setup ## Check Python code formatting and style without making change $(VENV_BIN)/black --check --extend-exclude="examples/notebook" . $(VENV_BIN)/blackdoc --check dbgpt examples $(VENV_BIN)/flake8 dbgpt/core/ + $(VENV_BIN)/flake8 dbgpt/rag/ # $(VENV_BIN)/blackdoc --check dbgpt examples # $(VENV_BIN)/flake8 dbgpt/core/ @@ -76,6 +78,7 @@ test-doc: $(VENV)/.testenv ## Run doctests mypy: $(VENV)/.testenv ## Run mypy checks # https://github.com/python/mypy $(VENV_BIN)/mypy --config-file .mypy.ini dbgpt/core/ + $(VENV_BIN)/mypy --config-file .mypy.ini dbgpt/rag/ # TODO: More package checks with mypy. .PHONY: coverage diff --git a/dbgpt/_private/config.py b/dbgpt/_private/config.py index f39b1921e..22f0d8b13 100644 --- a/dbgpt/_private/config.py +++ b/dbgpt/_private/config.py @@ -62,7 +62,7 @@ def __init__(self) -> None: if self.zhipu_proxy_api_key: os.environ["zhipu_proxyllm_proxy_api_key"] = self.zhipu_proxy_api_key os.environ["zhipu_proxyllm_proxyllm_backend"] = os.getenv( - "ZHIPU_MODEL_VERSION" + "ZHIPU_MODEL_VERSION", "" ) # wenxin @@ -74,7 +74,9 @@ def __init__(self) -> None: os.environ[ "wenxin_proxyllm_proxy_api_secret" ] = self.wenxin_proxy_api_secret - os.environ["wenxin_proxyllm_proxyllm_backend"] = self.wenxin_model_version + os.environ["wenxin_proxyllm_proxyllm_backend"] = ( + self.wenxin_model_version or "" + ) # xunfei spark self.spark_api_version = os.getenv("XUNFEI_SPARK_API_VERSION") @@ -84,8 +86,10 @@ def __init__(self) -> None: if self.spark_proxy_api_key and self.spark_proxy_api_secret: os.environ["spark_proxyllm_proxy_api_key"] = self.spark_proxy_api_key os.environ["spark_proxyllm_proxy_api_secret"] = self.spark_proxy_api_secret - os.environ["spark_proxyllm_proxyllm_backend"] = self.spark_api_version - os.environ["spark_proxyllm_proxy_api_app_id"] = self.spark_proxy_api_appid + os.environ["spark_proxyllm_proxyllm_backend"] = self.spark_api_version or "" + os.environ["spark_proxyllm_proxy_api_app_id"] = ( + self.spark_proxy_api_appid or "" + ) # baichuan proxy self.bc_proxy_api_key = os.getenv("BAICHUAN_PROXY_API_KEY") @@ -108,12 +112,10 @@ def __init__(self) -> None: self.elevenlabs_voice_1_id = os.getenv("ELEVENLABS_VOICE_1_ID") self.elevenlabs_voice_2_id = os.getenv("ELEVENLABS_VOICE_2_ID") - self.use_mac_os_tts = False - self.use_mac_os_tts = os.getenv("USE_MAC_OS_TTS") + self.use_mac_os_tts = os.getenv("USE_MAC_OS_TTS", "False") == "True" self.authorise_key = os.getenv("AUTHORISE_COMMAND_KEY", "y") self.exit_key = os.getenv("EXIT_KEY", "n") - self.image_provider = os.getenv("IMAGE_PROVIDER", True) self.image_size = int(os.getenv("IMAGE_SIZE", 256)) self.huggingface_api_token = os.getenv("HUGGINGFACE_API_TOKEN") @@ -131,10 +133,7 @@ def __init__(self) -> None: self.prompt_template_registry = PromptTemplateRegistry() ### Related configuration of built-in commands - self.command_registry = [] - - ### Relate configuration of display commands - self.command_dispaly = [] + self.command_registry = [] # type: ignore disabled_command_categories = os.getenv("DISABLED_COMMAND_CATEGORIES") if disabled_command_categories: @@ -151,7 +150,7 @@ def __init__(self) -> None: ### The associated configuration parameters of the plug-in control the loading and use of the plug-in self.plugins: List["AutoGPTPluginTemplate"] = [] - self.plugins_openai = [] + self.plugins_openai = [] # type: ignore self.plugins_auto_load = os.getenv("AUTO_LOAD_PLUGIN", "True").lower() == "true" self.plugins_git_branch = os.getenv("PLUGINS_GIT_BRANCH", "plugin_dashboard") @@ -274,6 +273,6 @@ def __init__(self) -> None: self.MODEL_CACHE_MAX_MEMORY_MB: int = int( os.getenv("MODEL_CACHE_MAX_MEMORY_MB", 256) ) - self.MODEL_CACHE_STORAGE_DISK_DIR: str = os.getenv( + self.MODEL_CACHE_STORAGE_DISK_DIR: Optional[str] = os.getenv( "MODEL_CACHE_STORAGE_DISK_DIR" ) diff --git a/dbgpt/component.py b/dbgpt/component.py index 26fbac0fa..7f96fbdc3 100644 --- a/dbgpt/component.py +++ b/dbgpt/component.py @@ -210,7 +210,7 @@ def register_instance(self, instance: T) -> T: def get_component( self, name: Union[str, ComponentType], - component_type: Type[T], + component_type: Type, default_component=_EMPTY_DEFAULT_COMPONENT, or_register_component: Optional[Type[T]] = None, *args, diff --git a/dbgpt/core/awel/operators/base.py b/dbgpt/core/awel/operators/base.py index ddbe9b612..f6778b5f4 100644 --- a/dbgpt/core/awel/operators/base.py +++ b/dbgpt/core/awel/operators/base.py @@ -81,7 +81,7 @@ def apply_defaults(self: "BaseOperator", *args: Any, **kwargs: Any) -> Any: if system_app: executor = system_app.get_component( ComponentType.EXECUTOR_DEFAULT, DefaultExecutorFactory - ).create() + ).create() # type: ignore else: executor = DefaultExecutorFactory().create() DAGVar.set_executor(executor) diff --git a/dbgpt/datasource/rdbms/base.py b/dbgpt/datasource/rdbms/base.py index 5eb415e62..f890cba6c 100644 --- a/dbgpt/datasource/rdbms/base.py +++ b/dbgpt/datasource/rdbms/base.py @@ -1,6 +1,6 @@ from __future__ import annotations -from typing import Any, Dict, Iterable, List, Optional +from typing import Any, Dict, Iterable, List, Optional, Tuple from urllib.parse import quote from urllib.parse import quote_plus as urlquote @@ -499,7 +499,7 @@ def get_show_create_table(self, table_name): ans = cursor.fetchall() return ans[0][1] - def get_fields(self, table_name): + def get_fields(self, table_name) -> List[Tuple]: """Get column fields about specified table.""" session = self._db_sessions() cursor = session.execute( diff --git a/dbgpt/datasource/rdbms/conn_clickhouse.py b/dbgpt/datasource/rdbms/conn_clickhouse.py index 06277050e..ed1c1dbb1 100644 --- a/dbgpt/datasource/rdbms/conn_clickhouse.py +++ b/dbgpt/datasource/rdbms/conn_clickhouse.py @@ -1,5 +1,5 @@ import re -from typing import Any, Dict, Iterable, List, Optional +from typing import Any, Dict, Iterable, List, Optional, Tuple import sqlparse from sqlalchemy import MetaData, text @@ -145,7 +145,7 @@ def get_columns(self, table_name: str) -> List[Dict]: for name, column_type, _, _, comment in fields[0] ] - def get_fields(self, table_name): + def get_fields(self, table_name) -> List[Tuple]: """Get column fields about specified table.""" session = self.client diff --git a/dbgpt/datasource/rdbms/conn_doris.py b/dbgpt/datasource/rdbms/conn_doris.py index 10b290af0..ddf28d397 100644 --- a/dbgpt/datasource/rdbms/conn_doris.py +++ b/dbgpt/datasource/rdbms/conn_doris.py @@ -1,4 +1,4 @@ -from typing import Any, Iterable, Optional +from typing import Any, Iterable, List, Optional, Tuple from urllib.parse import quote from urllib.parse import quote_plus as urlquote @@ -68,7 +68,7 @@ def get_users(self): """Get user info.""" return [] - def get_fields(self, table_name): + def get_fields(self, table_name) -> List[Tuple]: """Get column fields about specified table.""" cursor = self.get_session().execute( text( diff --git a/dbgpt/datasource/rdbms/conn_postgresql.py b/dbgpt/datasource/rdbms/conn_postgresql.py index 7470d5d90..71b63da07 100644 --- a/dbgpt/datasource/rdbms/conn_postgresql.py +++ b/dbgpt/datasource/rdbms/conn_postgresql.py @@ -1,4 +1,4 @@ -from typing import Any, Iterable, Optional +from typing import Any, Iterable, List, Optional, Tuple from urllib.parse import quote from urllib.parse import quote_plus as urlquote @@ -85,7 +85,7 @@ def get_users(self): print("postgresql get users error: ", e) return [] - def get_fields(self, table_name): + def get_fields(self, table_name) -> List[Tuple]: """Get column fields about specified table.""" session = self._db_sessions() cursor = session.execute( diff --git a/dbgpt/datasource/rdbms/conn_sqlite.py b/dbgpt/datasource/rdbms/conn_sqlite.py index 9eb9f735f..c9f442d89 100644 --- a/dbgpt/datasource/rdbms/conn_sqlite.py +++ b/dbgpt/datasource/rdbms/conn_sqlite.py @@ -4,7 +4,7 @@ import logging import os import tempfile -from typing import Any, Iterable, Optional +from typing import Any, Iterable, List, Optional, Tuple from sqlalchemy import create_engine, text @@ -58,7 +58,7 @@ def get_show_create_table(self, table_name): ans = cursor.fetchall() return ans[0][0] - def get_fields(self, table_name): + def get_fields(self, table_name) -> List[Tuple]: """Get column fields about specified table.""" cursor = self.session.execute(text(f"PRAGMA table_info('{table_name}')")) fields = cursor.fetchall() diff --git a/dbgpt/datasource/rdbms/conn_starrocks.py b/dbgpt/datasource/rdbms/conn_starrocks.py index 0c79b5d2a..26079b1fb 100644 --- a/dbgpt/datasource/rdbms/conn_starrocks.py +++ b/dbgpt/datasource/rdbms/conn_starrocks.py @@ -1,4 +1,4 @@ -from typing import Any, Iterable, Optional +from typing import Any, Iterable, List, Optional, Tuple from urllib.parse import quote from urllib.parse import quote_plus as urlquote @@ -68,7 +68,7 @@ def get_users(self): """Get user info.""" return [] - def get_fields(self, table_name, db_name="database()"): + def get_fields(self, table_name, db_name="database()") -> List[Tuple]: """Get column fields about specified table.""" session = self._db_sessions() if db_name != "database()": diff --git a/dbgpt/rag/__init__.py b/dbgpt/rag/__init__.py index e69de29bb..99f150c72 100644 --- a/dbgpt/rag/__init__.py +++ b/dbgpt/rag/__init__.py @@ -0,0 +1 @@ +"""Module of RAG.""" diff --git a/dbgpt/rag/chunk.py b/dbgpt/rag/chunk.py index 75a061f4f..7fe75b0a5 100644 --- a/dbgpt/rag/chunk.py +++ b/dbgpt/rag/chunk.py @@ -1,46 +1,49 @@ +"""Chunk document schema.""" + import json import uuid from typing import Any, Dict -from pydantic import BaseModel, Field +from dbgpt._private.pydantic import BaseModel, Field class Document(BaseModel): """Document including document content, document metadata.""" - content: str = (Field(default="", description="document text content"),) + content: str = Field(default="", description="document text content") - metadata: Dict[str, Any] = ( - Field( - default_factory=dict, - description="metadata fields", - ), + metadata: Dict[str, Any] = Field( + default_factory=dict, + description="metadata fields", ) def set_content(self, content: str) -> None: - """Set the content""" + """Set document content.""" self.content = content def get_content(self) -> str: + """Get document content.""" return self.content @classmethod def langchain2doc(cls, document): - """Transformation from Langchain to Chunk Document format.""" + """Transform Langchain to Document format.""" metadata = document.metadata or {} return cls(content=document.page_content, metadata=metadata) @classmethod def doc2langchain(cls, chunk): - """Transformation from Chunk to Langchain Document format.""" + """Transform Document to Langchain format.""" from langchain.schema import Document as LCDocument return LCDocument(page_content=chunk.content, metadata=chunk.metadata) class Chunk(Document): - """ - Document Chunk including chunk content, chunk metadata, chunk summary, chunk relations. + """The chunk document schema. + + Document Chunk including chunk content, chunk metadata, chunk summary, chunk + relations. """ chunk_id: str = Field( @@ -48,11 +51,9 @@ class Chunk(Document): ) content: str = Field(default="", description="chunk text content") - metadata: Dict[str, Any] = ( - Field( - default_factory=dict, - description="metadata fields", - ), + metadata: Dict[str, Any] = Field( + default_factory=dict, + description="metadata fields", ) score: float = Field(default=0.0, description="chunk text similarity score") summary: str = Field(default="", description="chunk text summary") @@ -62,22 +63,27 @@ class Chunk(Document): ) def to_dict(self, **kwargs: Any) -> Dict[str, Any]: + """Convert Chunk to dict.""" data = self.dict(**kwargs) data["class_name"] = self.class_name() return data def to_json(self, **kwargs: Any) -> str: + """Convert Chunk to json.""" data = self.to_dict(**kwargs) return json.dumps(data) def __hash__(self): + """Hash function.""" return hash((self.chunk_id,)) def __eq__(self, other): + """Equal function.""" return self.chunk_id == other.chunk_id @classmethod def from_dict(cls, data: Dict[str, Any], **kwargs: Any): # type: ignore + """Create Chunk from dict.""" if isinstance(kwargs, dict): data.update(kwargs) @@ -86,31 +92,32 @@ def from_dict(cls, data: Dict[str, Any], **kwargs: Any): # type: ignore @classmethod def from_json(cls, data_str: str, **kwargs: Any): # type: ignore + """Create Chunk from json.""" data = json.loads(data_str) return cls.from_dict(data, **kwargs) @classmethod def langchain2chunk(cls, document): - """Transformation from Langchain to Chunk Document format.""" + """Transform Langchain to Chunk format.""" metadata = document.metadata or {} - return cls(content=document.page_content, metadata=document.metadata) + return cls(content=document.page_content, metadata=metadata) @classmethod def llamaindex2chunk(cls, node): - """Transformation from LLama-Index to Chunk Document format.""" + """Transform llama-index to Chunk format.""" metadata = node.metadata or {} return cls(content=node.content, metadata=metadata) @classmethod def chunk2langchain(cls, chunk): - """Transformation from Chunk to Langchain Document format.""" + """Transform Chunk to Langchain format.""" from langchain.schema import Document as LCDocument return LCDocument(page_content=chunk.content, metadata=chunk.metadata) @classmethod def chunk2llamaindex(cls, chunk): - """Transformation from Chunk to LLama-Index Document format.""" + """Transform Chunk to llama-index format.""" from llama_index.schema import TextNode return TextNode(text=chunk.content, metadata=chunk.metadata) diff --git a/dbgpt/rag/chunk_manager.py b/dbgpt/rag/chunk_manager.py index d462a958d..550fa6d0e 100644 --- a/dbgpt/rag/chunk_manager.py +++ b/dbgpt/rag/chunk_manager.py @@ -1,3 +1,5 @@ +"""Module for ChunkManager.""" + from enum import Enum from typing import Any, List, Optional @@ -9,7 +11,7 @@ class SplitterType(Enum): - """splitter type""" + """The type of splitter.""" LANGCHAIN = "langchain" LLAMA_INDEX = "llama-index" @@ -17,7 +19,7 @@ class SplitterType(Enum): class ChunkParameters(BaseModel): - """ChunkParameters""" + """The parameters for chunking.""" chunk_strategy: str = Field( default=None, @@ -52,15 +54,16 @@ class ChunkParameters(BaseModel): class ChunkManager: - """ChunkManager""" + """Manager for chunks.""" def __init__( self, - knowledge: Knowledge = None, + knowledge: Knowledge, chunk_parameter: Optional[ChunkParameters] = None, extractor: Optional[Extractor] = None, ): - """ + """Create a new ChunkManager with the given knowledge. + Args: knowledge: (Knowledge) Knowledge datasource. chunk_parameter: (Optional[ChunkParameter]) Chunk parameter. @@ -72,10 +75,11 @@ def __init__( self._chunk_parameters = chunk_parameter or ChunkParameters() self._chunk_strategy = ( chunk_parameter.chunk_strategy - or self._knowledge.default_chunk_strategy().name + if chunk_parameter and chunk_parameter.chunk_strategy + else self._knowledge.default_chunk_strategy().name ) - self._text_splitter = chunk_parameter.text_splitter - self._splitter_type = chunk_parameter.splitter_type + self._text_splitter = self._chunk_parameters.text_splitter + self._splitter_type = self._chunk_parameters.splitter_type def split(self, documents) -> List[Chunk]: """Split a document into chunks.""" @@ -92,18 +96,18 @@ def split(self, documents) -> List[Chunk]: def split_with_summary( self, document: Any, chunk_strategy: ChunkStrategy ) -> List[Chunk]: - """Split a document into chunks and summary""" - + """Split a document into chunks and summary.""" raise NotImplementedError @property def chunk_parameters(self) -> ChunkParameters: + """Get chunk parameters.""" return self._chunk_parameters def set_text_splitter( self, text_splitter, - splitter_type: Optional[SplitterType] = SplitterType.LANGCHAIN, + splitter_type: SplitterType = SplitterType.LANGCHAIN, ) -> None: """Add text splitter.""" self._text_splitter = text_splitter @@ -112,7 +116,7 @@ def set_text_splitter( def get_text_splitter( self, ) -> Any: - """get text splitter.""" + """Return text splitter.""" return self._select_text_splitter() def _select_text_splitter( @@ -121,7 +125,7 @@ def _select_text_splitter( """Select text splitter by chunk strategy.""" if self._text_splitter: return self._text_splitter - if not self._chunk_strategy or "Automatic" == self._chunk_strategy: + if not self._chunk_strategy or self._chunk_strategy == "Automatic": self._chunk_strategy = self._knowledge.default_chunk_strategy().name if self._chunk_strategy not in [ support_chunk_strategy.name @@ -131,7 +135,8 @@ def _select_text_splitter( if self._knowledge.document_type(): current_type = self._knowledge.document_type().value raise ValueError( - f"{current_type} knowledge not supported chunk strategy {self._chunk_strategy} " + f"{current_type} knowledge not supported chunk strategy " + f"{self._chunk_strategy} " ) strategy = ChunkStrategy[self._chunk_strategy] return strategy.match( diff --git a/dbgpt/rag/embedding/__init__.py b/dbgpt/rag/embedding/__init__.py index a4d6d2914..dcf799f3a 100644 --- a/dbgpt/rag/embedding/__init__.py +++ b/dbgpt/rag/embedding/__init__.py @@ -1,16 +1,24 @@ -from .embedding_factory import DefaultEmbeddingFactory, EmbeddingFactory -from .embeddings import ( +"""Module for embedding related classes and functions.""" + +from .embedding_factory import DefaultEmbeddingFactory, EmbeddingFactory # noqa: F401 +from .embeddings import ( # noqa: F401 Embeddings, + HuggingFaceBgeEmbeddings, HuggingFaceEmbeddings, + HuggingFaceInferenceAPIEmbeddings, + HuggingFaceInstructEmbeddings, JinaEmbeddings, OpenAPIEmbeddings, ) __ALL__ = [ - "OpenAPIEmbeddings", "Embeddings", + "HuggingFaceBgeEmbeddings", "HuggingFaceEmbeddings", + "HuggingFaceInferenceAPIEmbeddings", + "HuggingFaceInstructEmbeddings", "JinaEmbeddings", - "EmbeddingFactory", + "OpenAPIEmbeddings", "DefaultEmbeddingFactory", + "EmbeddingFactory", ] diff --git a/dbgpt/rag/embedding/embedding_factory.py b/dbgpt/rag/embedding/embedding_factory.py index 96864a1ff..83237bb76 100644 --- a/dbgpt/rag/embedding/embedding_factory.py +++ b/dbgpt/rag/embedding/embedding_factory.py @@ -1,9 +1,11 @@ +"""EmbeddingFactory class and DefaultEmbeddingFactory class.""" + from __future__ import annotations from abc import ABC, abstractmethod -from typing import TYPE_CHECKING, Any, Type +from typing import TYPE_CHECKING, Any, Optional, Type -from dbgpt.component import BaseComponent +from dbgpt.component import BaseComponent, SystemApp from dbgpt.rag.embedding.embeddings import HuggingFaceEmbeddings if TYPE_CHECKING: @@ -17,25 +19,46 @@ class EmbeddingFactory(BaseComponent, ABC): @abstractmethod def create( - self, model_name: str = None, embedding_cls: Type = None + self, model_name: Optional[str] = None, embedding_cls: Optional[Type] = None ) -> "Embeddings": - """Create embedding""" + """Create an embedding instance. + + Args: + model_name (str): The model name. + embedding_cls (Type): The embedding class. + + Returns: + Embeddings: The embedding instance. + """ class DefaultEmbeddingFactory(EmbeddingFactory): + """The default embedding factory.""" + def __init__( - self, system_app=None, default_model_name: str = None, **kwargs: Any + self, + system_app: Optional[SystemApp] = None, + default_model_name: Optional[str] = None, + **kwargs: Any, ) -> None: + """Create a new DefaultEmbeddingFactory.""" super().__init__(system_app=system_app) self._default_model_name = default_model_name self.kwargs = kwargs def init_app(self, system_app): + """Init the app.""" pass def create( - self, model_name: str = None, embedding_cls: Type = None + self, model_name: Optional[str] = None, embedding_cls: Optional[Type] = None ) -> "Embeddings": + """Create an embedding instance. + + Args: + model_name (str): The model name. + embedding_cls (Type): The embedding class. + """ if not model_name: model_name = self._default_model_name diff --git a/dbgpt/rag/embedding/embeddings.py b/dbgpt/rag/embedding/embeddings.py index 7ccc78ee5..f02623b5f 100644 --- a/dbgpt/rag/embedding/embeddings.py +++ b/dbgpt/rag/embedding/embeddings.py @@ -1,3 +1,5 @@ +"""Embedding implementations.""" + import asyncio from abc import ABC, abstractmethod from typing import Any, Dict, List, Optional @@ -21,9 +23,11 @@ class Embeddings(ABC): - """Interface for embedding models.""" + """Interface for embedding models. - """refer to https://github.com/langchain-ai/langchain/tree/master/libs/langchain/langchain/embeddings""" + Refer to `Langchain Embeddings `_. + """ @abstractmethod def embed_documents(self, texts: List[str]) -> List[List[float]]: @@ -48,12 +52,16 @@ async def aembed_query(self, text: str) -> List[float]: class HuggingFaceEmbeddings(BaseModel, Embeddings): """HuggingFace sentence_transformers embedding models. + To use, you should have the ``sentence_transformers`` python package installed. - Refer to https://github.com/langchain-ai/langchain/tree/master/libs/langchain/langchain/embeddings + + Refer to `Langchain Embeddings `_. + Example: .. code-block:: python - from .embeddings import HuggingFaceEmbeddings + from dbgpt.rag.embedding import HuggingFaceEmbeddings model_name = "sentence-transformers/all-mpnet-base-v2" model_kwargs = {"device": "cpu"} @@ -69,8 +77,8 @@ class HuggingFaceEmbeddings(BaseModel, Embeddings): model_name: str = DEFAULT_MODEL_NAME """Model name to use.""" cache_folder: Optional[str] = None - """Path to store models. - Can be also set by SENTENCE_TRANSFORMERS_HOME environment variable.""" + """Path to store models. Can be also set by SENTENCE_TRANSFORMERS_HOME + environment variable.""" model_kwargs: Dict[str, Any] = Field(default_factory=dict) """Keyword arguments to pass to the model.""" encode_kwargs: Dict[str, Any] = Field(default_factory=dict) @@ -141,7 +149,7 @@ class HuggingFaceInstructEmbeddings(BaseModel, Embeddings): Example: .. code-block:: python - from langchain.embeddings import HuggingFaceInstructEmbeddings + from dbgpt.rag.embeddings import HuggingFaceInstructEmbeddings model_name = "hkunlp/instructor-large" model_kwargs = {"device": "cpu"} @@ -157,8 +165,8 @@ class HuggingFaceInstructEmbeddings(BaseModel, Embeddings): model_name: str = DEFAULT_INSTRUCT_MODEL """Model name to use.""" cache_folder: Optional[str] = None - """Path to store models. - Can be also set by SENTENCE_TRANSFORMERS_HOME environment variable.""" + """Path to store models. Can be also set by SENTENCE_TRANSFORMERS_HOME + environment variable.""" model_kwargs: Dict[str, Any] = Field(default_factory=dict) """Keyword arguments to pass to the model.""" encode_kwargs: Dict[str, Any] = Field(default_factory=dict) @@ -216,11 +224,14 @@ class HuggingFaceBgeEmbeddings(BaseModel, Embeddings): """HuggingFace BGE sentence_transformers embedding models. To use, you should have the ``sentence_transformers`` python package installed. - refer to https://github.com/langchain-ai/langchain/tree/master/libs/langchain/langchain/embeddings + + refer to `Langchain Embeddings `_. + Example: .. code-block:: python - from langchain.embeddings import HuggingFaceBgeEmbeddings + from dbgpt.rag.embeddings import HuggingFaceBgeEmbeddings model_name = "BAAI/bge-large-en" model_kwargs = {"device": "cpu"} @@ -389,28 +400,30 @@ def _handle_request_result(res: requests.Response) -> List[List[float]]: class JinaEmbeddings(BaseModel, Embeddings): - """ + """Jina AI embeddings. + This class is used to get embeddings for a list of texts using the Jina AI API. - It requires an API key and a model name. The default model name is "jina-embeddings-v2-base-en". + It requires an API key and a model name. The default model name is + "jina-embeddings-v2-base-en". """ api_url: Any #: :meta private: session: Any #: :meta private: api_key: str - """our API key for the Jina AI API..""" + """API key for the Jina AI API..""" model_name: str = "jina-embeddings-v2-base-en" - """he name of the model to use for text embeddings. Defaults to "jina-embeddings-v2-base-en".""" + """The name of the model to use for text embeddings. Defaults to + "jina-embeddings-v2-base-en".""" def __init__(self, **kwargs): - """ - Initialize the JinaEmbeddings. - """ + """Create a new JinaEmbeddings instance.""" super().__init__(**kwargs) try: import requests except ImportError: raise ValueError( - "The requests python package is not installed. Please install it with `pip install requests`" + "The requests python package is not installed. Please install it with " + "`pip install requests`" ) self.api_url = "https://api.jina.ai/v1/embeddings" self.session = requests.Session() @@ -432,10 +445,10 @@ def embed_documents(self, texts: List[str]) -> List[List[float]]: resp = self.session.post( # type: ignore self.api_url, json={"input": texts, "model": self.model_name} ) - return _handle_request_result(res) + return _handle_request_result(resp) def embed_query(self, text: str) -> List[float]: - """Compute query embeddings using a HuggingFace transformer model. + """Compute query embeddings using a Jina AI embedding model. Args: text: The text to embed. @@ -447,12 +460,12 @@ def embed_query(self, text: str) -> List[float]: class OpenAPIEmbeddings(BaseModel, Embeddings): - """This class is used to get embeddings for a list of texts using the API. + """The OpenAPI embeddings. + This class is used to get embeddings for a list of texts using the API. This API is compatible with the OpenAI Embedding API. Examples: - Using OpenAI's API: .. code-block:: python @@ -504,7 +517,6 @@ class OpenAPIEmbeddings(BaseModel, Embeddings): ) texts = ["Hello, world!", "How are you?"] openai_embeddings.embed_documents(texts) - """ api_url: str = Field( @@ -521,7 +533,7 @@ class OpenAPIEmbeddings(BaseModel, Embeddings): default=60, description="The timeout for the request in seconds." ) - session: requests.Session = None + session: Optional[requests.Session] = None class Config: """Configuration for this pydantic object.""" diff --git a/dbgpt/rag/extractor/__init__.py b/dbgpt/rag/extractor/__init__.py index e69de29bb..cdcb10ba0 100644 --- a/dbgpt/rag/extractor/__init__.py +++ b/dbgpt/rag/extractor/__init__.py @@ -0,0 +1,5 @@ +"""Module for extracting information.""" +from .base import Extractor +from .summary import SummaryExtractor + +__all__ = ["Extractor", "SummaryExtractor"] diff --git a/dbgpt/rag/extractor/base.py b/dbgpt/rag/extractor/base.py index b9049b176..6ba57f231 100644 --- a/dbgpt/rag/extractor/base.py +++ b/dbgpt/rag/extractor/base.py @@ -1,3 +1,4 @@ +"""Base Extractor Base class.""" from abc import ABC, abstractmethod from typing import List @@ -6,14 +7,18 @@ class Extractor(ABC): - """Extractor Base class, it's apply for Summary Extractor, Keyword Extractor, Triplets Extractor, Question Extractor, etc.""" + """Base Extractor Base class. + + It's apply for Summary Extractor, Keyword Extractor, Triplets Extractor, Question + Extractor, etc. + """ def __init__(self, llm_client: LLMClient) -> None: """Initialize the Extractor.""" self._llm_client = llm_client def extract(self, chunks: List[Chunk]) -> str: - """Extracts chunks. + """Return extracted metadata from chunks. Args: chunks (List[Chunk]): extract metadata from chunks @@ -30,7 +35,7 @@ async def aextract(self, chunks: List[Chunk]) -> str: @abstractmethod def _extract(self, chunks: List[Chunk]) -> str: - """Extracts chunks. + """Return extracted metadata from chunks. Args: chunks (List[Chunk]): extract metadata from chunks diff --git a/dbgpt/rag/extractor/summary.py b/dbgpt/rag/extractor/summary.py index 0e1cb1b77..28669a0f7 100644 --- a/dbgpt/rag/extractor/summary.py +++ b/dbgpt/rag/extractor/summary.py @@ -1,3 +1,5 @@ +"""Summary Extractor, it can extract document summary.""" + from typing import List, Optional from dbgpt._private.llm_metadata import LLMMetadata @@ -13,16 +15,21 @@ """ SUMMARY_PROMPT_TEMPLATE_EN = """ -Write a quick summary of the following context: +Write a quick summary of the following context: {context} -the summary should be as concise as possible and not overly lengthy.Please keep the answer within approximately 200 characters. +the summary should be as concise as possible and not overly lengthy.Please keep the +answer within approximately 200 characters. """ -REFINE_SUMMARY_TEMPLATE_ZH = """我们已经提供了一个到某一点的现有总结:{context}\n 请根据你之前推理的内容进行总结,总结回答的时候最好按照1.2.3.进行. 注意:请用<中文>来进行总结。""" +REFINE_SUMMARY_TEMPLATE_ZH = """我们已经提供了一个到某一点的现有总结:{context} +请根据你之前推理的内容进行总结,总结回答的时候最好按照1.2.3.进行. 注意:请用<中文>来进行总结。 +""" REFINE_SUMMARY_TEMPLATE_EN = """ -We have provided an existing summary up to a certain point: {context}, We have the opportunity to refine the existing summary (only if needed) with some more context below. -\nBased on the previous reasoning, please summarize the final conclusion in accordance with points 1.2.and 3. +We have provided an existing summary up to a certain point: {context}, We have the +opportunity to refine the existing summary (only if needed) with some more context +below. \nBased on the previous reasoning, please summarize the final conclusion in +accordance with points 1.2.and 3. """ @@ -31,18 +38,29 @@ class SummaryExtractor(Extractor): def __init__( self, - llm_client: Optional[LLMClient], - model_name: Optional[str] = None, + llm_client: LLMClient, + model_name: str, llm_metadata: Optional[LLMMetadata] = None, language: Optional[str] = "en", - max_iteration_with_llm: Optional[int] = 5, - concurrency_limit_with_llm: Optional[int] = 3, + max_iteration_with_llm: int = 5, + concurrency_limit_with_llm: int = 3, ): + """Create SummaryExtractor. + + Args: + llm_client: (Optional[LLMClient]): The LLM client. Defaults to None. + model_name: str + llm_metadata: LLMMetadata + language: (Optional[str]): The language of the prompt. Defaults to "en". + max_iteration_with_llm: (Optional[int]): The max iteration with llm. + Defaults to 5. + concurrency_limit_with_llm: (Optional[int]): The concurrency limit with llm. + Defaults to 3. + """ self._llm_client = llm_client self._model_name = model_name - self.llm_metadata = llm_metadata or LLMMetadata + self.llm_metadata = llm_metadata self._language = language - self._concurrency_limit_with_llm = concurrency_limit_with_llm self._prompt_template = ( SUMMARY_PROMPT_TEMPLATE_EN if language == "en" @@ -55,23 +73,15 @@ def __init__( ) self._concurrency_limit_with_llm = concurrency_limit_with_llm self._max_iteration_with_llm = max_iteration_with_llm - self._concurrency_limit_with_llm = concurrency_limit_with_llm - - """Initialize the Extractor. - Args: - llm_client: (Optional[LLMClient]): The LLM client. Defaults to None. - model_name: str - llm_metadata: LLMMetadata - language: (Optional[str]): The language of the prompt. Defaults to "en". - max_iteration_with_llm: (Optional[int]): The max iteration with llm. Defaults to 5. - concurrency_limit_with_llm: (Optional[int]): The concurrency limit with llm. Defaults to 3. - """ async def _aextract(self, chunks: List[Chunk]) -> str: - """async document extract summary + """Return extracted metadata from chunks of async. + Args: - - model_name: str - - chunk_docs: List[Document] + chunks (List[Chunk]): extract metadata from chunks + + Returns: + str: The summary of the documents. """ texts = [doc.content for doc in chunks] from dbgpt.util.prompt_util import PromptHelper @@ -95,9 +105,13 @@ async def _aextract(self, chunks: List[Chunk]) -> str: return summary_outs[0] def _extract(self, chunks: List[Chunk]) -> str: - """document extract summary + """Return summary of the documents. + Args: - - chunk_docs: List[Document] + chunks(List[Chunk]): list of chunks + + Returns: + summary: str """ loop = utils.get_or_create_event_loop() return loop.run_until_complete(self._aextract(chunks=chunks)) @@ -106,7 +120,10 @@ async def _mapreduce_extract_summary( self, docs: List[str], ) -> str: - """Extract summary by mapreduce mode + """Return the summary of the documents. + + Extract summary by mapreduce mode. + map -> multi async call llm to generate summary reduce -> merge the summaries by map process Args: @@ -132,10 +149,12 @@ async def _mapreduce_extract_summary( async def _llm_run_tasks( self, chunk_texts: List[str], prompt_template: str ) -> List[str]: - """llm run tasks + """Run llm tasks. + Args: chunk_texts: List[str] prompt_template: str + Returns: summary_outs: List[str] """ diff --git a/dbgpt/rag/graph/__init__.py b/dbgpt/rag/graph/__init__.py index e69de29bb..bcbf461b9 100644 --- a/dbgpt/rag/graph/__init__.py +++ b/dbgpt/rag/graph/__init__.py @@ -0,0 +1 @@ +"""Graph module for RAG.""" diff --git a/dbgpt/rag/knowledge/__init__.py b/dbgpt/rag/knowledge/__init__.py index e69de29bb..0fa28c518 100644 --- a/dbgpt/rag/knowledge/__init__.py +++ b/dbgpt/rag/knowledge/__init__.py @@ -0,0 +1,29 @@ +"""Module Of Knowledge.""" + +from .base import ChunkStrategy, Knowledge, KnowledgeType # noqa: F401 +from .csv import CSVKnowledge # noqa: F401 +from .docx import DocxKnowledge # noqa: F401 +from .factory import KnowledgeFactory # noqa: F401 +from .html import HTMLKnowledge # noqa: F401 +from .markdown import MarkdownKnowledge # noqa: F401 +from .pdf import PDFKnowledge # noqa: F401 +from .pptx import PPTXKnowledge # noqa: F401 +from .string import StringKnowledge # noqa: F401 +from .txt import TXTKnowledge # noqa: F401 +from .url import URLKnowledge # noqa: F401 + +__ALL__ = [ + "KnowledgeFactory", + "Knowledge", + "KnowledgeType", + "ChunkStrategy", + "CSVKnowledge", + "DocxKnowledge", + "HTMLKnowledge", + "MarkdownKnowledge", + "PDFKnowledge", + "PPTXKnowledge", + "StringKnowledge", + "TXTKnowledge", + "URLKnowledge", +] diff --git a/dbgpt/rag/knowledge/base.py b/dbgpt/rag/knowledge/base.py index af149b164..21187c61b 100644 --- a/dbgpt/rag/knowledge/base.py +++ b/dbgpt/rag/knowledge/base.py @@ -1,19 +1,23 @@ +"""Module for Knowledge Base.""" + from abc import ABC, abstractmethod from enum import Enum -from typing import Any, List, Optional +from typing import Any, List, Optional, Tuple, Type from dbgpt.rag.chunk import Document from dbgpt.rag.text_splitter.text_splitter import ( - CharacterTextSplitter, MarkdownHeaderTextSplitter, PageTextSplitter, ParagraphTextSplitter, RecursiveCharacterTextSplitter, SeparatorTextSplitter, + TextSplitter, ) class DocumentType(Enum): + """Document Type Enum.""" + PDF = "pdf" CSV = "csv" MARKDOWN = "md" @@ -24,27 +28,40 @@ class DocumentType(Enum): class KnowledgeType(Enum): + """Knowledge Type Enum.""" + DOCUMENT = "DOCUMENT" URL = "URL" TEXT = "TEXT" @property def type(self): + """Get type.""" return DocumentType @classmethod - def get_by_value(cls, value): - """Get Enum member by value""" + def get_by_value(cls, value) -> "KnowledgeType": + """Get Enum member by value. + + Args: + value(any): value + + Returns: + KnowledgeType: Enum member + """ for member in cls: if member.value == value: return member raise ValueError(f"{value} is not a valid value for {cls.__name__}") +_STRATEGY_ENUM_TYPE = Tuple[Type[TextSplitter], List, str, str] + + class ChunkStrategy(Enum): - """chunk strategy""" + """Chunk Strategy Enum.""" - CHUNK_BY_SIZE = ( + CHUNK_BY_SIZE: _STRATEGY_ENUM_TYPE = ( RecursiveCharacterTextSplitter, [ { @@ -63,8 +80,13 @@ class ChunkStrategy(Enum): "chunk size", "split document by chunk size", ) - CHUNK_BY_PAGE = (PageTextSplitter, [], "page", "split document by page") - CHUNK_BY_PARAGRAPH = ( + CHUNK_BY_PAGE: _STRATEGY_ENUM_TYPE = ( + PageTextSplitter, + [], + "page", + "split document by page", + ) + CHUNK_BY_PARAGRAPH: _STRATEGY_ENUM_TYPE = ( ParagraphTextSplitter, [ { @@ -77,7 +99,7 @@ class ChunkStrategy(Enum): "paragraph", "split document by paragraph", ) - CHUNK_BY_SEPARATOR = ( + CHUNK_BY_SEPARATOR: _STRATEGY_ENUM_TYPE = ( SeparatorTextSplitter, [ { @@ -90,13 +112,14 @@ class ChunkStrategy(Enum): "param_name": "enable_merge", "param_type": "boolean", "default_value": False, - "description": "Whether to merge according to the chunk_size after splitting by the separator.", + "description": "Whether to merge according to the chunk_size after " + "splitting by the separator.", }, ], "separator", "split document by separator", ) - CHUNK_BY_MARKDOWN_HEADER = ( + CHUNK_BY_MARKDOWN_HEADER: _STRATEGY_ENUM_TYPE = ( MarkdownHeaderTextSplitter, [], "markdown header", @@ -104,24 +127,26 @@ class ChunkStrategy(Enum): ) def __init__(self, splitter_class, parameters, alias, description): + """Create a new ChunkStrategy with the given splitter_class.""" self.splitter_class = splitter_class self.parameters = parameters self.alias = alias self.description = description - def match(self, *args, **kwargs): + def match(self, *args, **kwargs) -> TextSplitter: + """Match and build splitter.""" kwargs = {k: v for k, v in kwargs.items() if v is not None} return self.value[0](*args, **kwargs) class Knowledge(ABC): - type: KnowledgeType = None + """Knowledge Base Class.""" def __init__( self, path: Optional[str] = None, knowledge_type: Optional[KnowledgeType] = None, - data_loader: Optional = None, + data_loader: Optional[Any] = None, **kwargs: Any, ) -> None: """Initialize with Knowledge arguments.""" @@ -130,30 +155,31 @@ def __init__( self._data_loader = data_loader def load(self): - """Load knowledge from data_loader""" + """Load knowledge from data_loader.""" documents = self._load() return self._postprocess(documents) @classmethod + @abstractmethod def type(cls) -> KnowledgeType: - """Get knowledge type""" + """Get knowledge type.""" @classmethod def document_type(cls) -> Any: - """Get document type""" + """Get document type.""" return None def _postprocess(self, docs: List[Document]) -> List[Document]: - """Post process knowledge from data_loader""" + """Post process knowledge from data_loader.""" return docs @abstractmethod def _load(self): - """Preprocess knowledge from data_loader""" + """Preprocess knowledge from data_loader.""" @classmethod def support_chunk_strategy(cls) -> List[ChunkStrategy]: - """support chunk strategy""" + """Return supported chunk strategy.""" return [ ChunkStrategy.CHUNK_BY_SIZE, ChunkStrategy.CHUNK_BY_PAGE, @@ -162,11 +188,11 @@ def support_chunk_strategy(cls) -> List[ChunkStrategy]: ChunkStrategy.CHUNK_BY_SEPARATOR, ] - def default_chunk_strategy(self) -> ChunkStrategy: - return ChunkStrategy.CHUNK_BY_SIZE + @classmethod + def default_chunk_strategy(cls) -> ChunkStrategy: + """Return default chunk strategy. - def support_chunk_strategy(self): - return [ - ChunkStrategy.CHUNK_BY_SIZE, - ChunkStrategy.CHUNK_BY_SEPARATOR, - ] + Returns: + ChunkStrategy: default chunk strategy + """ + return ChunkStrategy.CHUNK_BY_SIZE diff --git a/dbgpt/rag/knowledge/csv.py b/dbgpt/rag/knowledge/csv.py index ec41cd7f5..bbe564a13 100644 --- a/dbgpt/rag/knowledge/csv.py +++ b/dbgpt/rag/knowledge/csv.py @@ -1,3 +1,4 @@ +"""CSV Knowledge.""" import csv from typing import Any, List, Optional @@ -11,7 +12,7 @@ class CSVKnowledge(Knowledge): - """CSV Knowledge""" + """CSV Knowledge.""" def __init__( self, @@ -22,13 +23,14 @@ def __init__( loader: Optional[Any] = None, **kwargs: Any, ) -> None: - """Initialize csv with Knowledge arguments. + """Create CSV Knowledge with Knowledge arguments. + Args: - file_path:(Optional[str]) file path - knowledge_type:(KnowledgeType) knowledge type - source_column:(Optional[str]) source column - encoding:(Optional[str]) csv encoding - loader:(Optional[Any]) loader + file_path(str, optional): file path + knowledge_type(KnowledgeType, optional): knowledge type + source_column(str, optional): source column + encoding(str, optional): csv encoding + loader(Any, optional): loader """ self._path = file_path self._type = knowledge_type @@ -37,11 +39,13 @@ def __init__( self._source_column = source_column def _load(self) -> List[Document]: - """Load csv document from loader""" + """Load csv document from loader.""" if self._loader: documents = self._loader.load() else: docs = [] + if not self._path: + raise ValueError("file path is required") with open(self._path, newline="", encoding=self._encoding) as csvfile: csv_reader = csv.DictReader(csvfile) for i, row in enumerate(csv_reader): @@ -59,7 +63,8 @@ def _load(self) -> List[Document]: ) except KeyError: raise ValueError( - f"Source column '{self._source_column}' not found in CSV file." + f"Source column '{self._source_column}' not found in CSV " + f"file." ) metadata = {"source": source, "row": i} doc = Document(content=content, metadata=metadata) @@ -70,6 +75,7 @@ def _load(self) -> List[Document]: @classmethod def support_chunk_strategy(cls) -> List[ChunkStrategy]: + """Return support chunk strategy.""" return [ ChunkStrategy.CHUNK_BY_SIZE, ChunkStrategy.CHUNK_BY_SEPARATOR, @@ -77,12 +83,15 @@ def support_chunk_strategy(cls) -> List[ChunkStrategy]: @classmethod def default_chunk_strategy(cls) -> ChunkStrategy: + """Return default chunk strategy.""" return ChunkStrategy.CHUNK_BY_SIZE @classmethod def type(cls) -> KnowledgeType: + """Knowledge type of CSV.""" return KnowledgeType.DOCUMENT @classmethod def document_type(cls) -> DocumentType: + """Return document type.""" return DocumentType.CSV diff --git a/dbgpt/rag/knowledge/docx.py b/dbgpt/rag/knowledge/docx.py index ae1075c77..b1c7e48b7 100644 --- a/dbgpt/rag/knowledge/docx.py +++ b/dbgpt/rag/knowledge/docx.py @@ -1,3 +1,4 @@ +"""Docx Knowledge.""" from typing import Any, List, Optional import docx @@ -12,7 +13,7 @@ class DocxKnowledge(Knowledge): - """Docx Knowledge""" + """Docx Knowledge.""" def __init__( self, @@ -22,12 +23,13 @@ def __init__( loader: Optional[Any] = None, **kwargs: Any, ) -> None: - """Initialize with Knowledge arguments. + """Create Docx Knowledge with Knowledge arguments. + Args: - file_path:(Optional[str]) file path - knowledge_type:(KnowledgeType) knowledge type - encoding:(Optional[str]) csv encoding - loader:(Optional[Any]) loader + file_path(str, optional): file path + knowledge_type(KnowledgeType, optional): knowledge type + encoding(str, optional): csv encoding + loader(Any, optional): loader """ self._path = file_path self._type = knowledge_type @@ -35,7 +37,7 @@ def __init__( self._encoding = encoding def _load(self) -> List[Document]: - """Load docx document from loader""" + """Load docx document from loader.""" if self._loader: documents = self._loader.load() else: @@ -54,6 +56,7 @@ def _load(self) -> List[Document]: @classmethod def support_chunk_strategy(cls) -> List[ChunkStrategy]: + """Return support chunk strategy.""" return [ ChunkStrategy.CHUNK_BY_SIZE, ChunkStrategy.CHUNK_BY_PARAGRAPH, @@ -62,12 +65,15 @@ def support_chunk_strategy(cls) -> List[ChunkStrategy]: @classmethod def default_chunk_strategy(cls) -> ChunkStrategy: + """Return default chunk strategy.""" return ChunkStrategy.CHUNK_BY_SIZE @classmethod def type(cls) -> KnowledgeType: + """Return knowledge type.""" return KnowledgeType.DOCUMENT @classmethod def document_type(cls) -> DocumentType: + """Return document type.""" return DocumentType.DOCX diff --git a/dbgpt/rag/knowledge/factory.py b/dbgpt/rag/knowledge/factory.py index 622d18b41..b60ef0b8f 100644 --- a/dbgpt/rag/knowledge/factory.py +++ b/dbgpt/rag/knowledge/factory.py @@ -1,4 +1,5 @@ -from typing import List, Optional +"""Knowledge Factory to create knowledge from file path and url.""" +from typing import List, Optional, Type from dbgpt.rag.knowledge.base import Knowledge, KnowledgeType from dbgpt.rag.knowledge.string import StringKnowledge @@ -6,17 +7,18 @@ class KnowledgeFactory: - """Knowledge Factory to create knowledge from file path and url""" + """Knowledge Factory to create knowledge from file path and url.""" def __init__( self, file_path: Optional[str] = None, knowledge_type: Optional[KnowledgeType] = KnowledgeType.DOCUMENT, ): - """Initialize with Knowledge Factory arguments. + """Create Knowledge Factory with file path and knowledge type. + Args: - param file_path: path of the file to convert - param knowledge_type: type of knowledge + file_path(str, optional): file path + knowledge_type(KnowledgeType, optional): knowledge type """ self._file_path = file_path self._knowledge_type = knowledge_type @@ -24,16 +26,16 @@ def __init__( @classmethod def create( cls, - datasource: Optional[str] = None, - knowledge_type: Optional[KnowledgeType] = KnowledgeType.DOCUMENT, + datasource: str = "", + knowledge_type: KnowledgeType = KnowledgeType.DOCUMENT, ): - """create knowledge from file path, url or text + """Create knowledge from file path, url or text. + Args: datasource: path of the file to convert knowledge_type: type of knowledge Examples: - .. code-block:: python from dbgpt.rag.knowledge.factory import KnowledgeFactory @@ -62,17 +64,16 @@ def create( @classmethod def from_file_path( cls, - file_path: Optional[str] = None, + file_path: str = "", knowledge_type: Optional[KnowledgeType] = KnowledgeType.DOCUMENT, ) -> Knowledge: - """Create knowledge from path + """Create knowledge from path. Args: param file_path: path of the file to convert param knowledge_type: type of knowledge Examples: - .. code-block:: python from dbgpt.rag.knowledge.factory import KnowledgeFactory @@ -81,7 +82,6 @@ def from_file_path( datasource="path/to/document.pdf", knowledge_type=KnowledgeType.DOCUMENT, ) - """ factory = cls(file_path=file_path, knowledge_type=knowledge_type) return factory._select_document_knowledge( @@ -90,17 +90,16 @@ def from_file_path( @staticmethod def from_url( - url: Optional[str] = None, - knowledge_type: Optional[KnowledgeType] = KnowledgeType.URL, + url: str = "", + knowledge_type: KnowledgeType = KnowledgeType.URL, ) -> Knowledge: - """Create knowledge from url + """Create knowledge from url. Args: param url: url of the file to convert param knowledge_type: type of knowledge Examples: - .. code-block:: python from dbgpt.rag.knowledge.factory import KnowledgeFactory @@ -108,7 +107,6 @@ def from_url( url_knowlege = KnowledgeFactory.create( datasource="https://www.baidu.com", knowledge_type=KnowledgeType.URL ) - """ return URLKnowledge( url=url, @@ -117,10 +115,11 @@ def from_url( @staticmethod def from_text( - text: str = None, - knowledge_type: Optional[KnowledgeType] = KnowledgeType.TEXT, + text: str = "", + knowledge_type: KnowledgeType = KnowledgeType.TEXT, ) -> Knowledge: - """Create knowledge from text + """Create knowledge from text. + Args: param text: text to convert param knowledge_type: type of knowledge @@ -131,7 +130,7 @@ def from_text( ) def _select_document_knowledge(self, **kwargs): - """Select document knowledge from file path""" + """Select document knowledge from file path.""" extension = self._file_path.rsplit(".", 1)[-1] knowledge_classes = self._get_knowledge_subclasses() implementation = None @@ -144,26 +143,26 @@ def _select_document_knowledge(self, **kwargs): @classmethod def all_types(cls): - """get all knowledge types""" + """Get all knowledge types.""" return [knowledge.type().value for knowledge in cls._get_knowledge_subclasses()] @classmethod - def subclasses(cls): - """get all knowledge subclasses""" + def subclasses(cls) -> List["Type[Knowledge]"]: + """Get all knowledge subclasses.""" return cls._get_knowledge_subclasses() @staticmethod - def _get_knowledge_subclasses() -> List[Knowledge]: - """get all knowledge subclasses""" - from dbgpt.rag.knowledge.base import Knowledge - from dbgpt.rag.knowledge.csv import CSVKnowledge - from dbgpt.rag.knowledge.docx import DocxKnowledge - from dbgpt.rag.knowledge.html import HTMLKnowledge - from dbgpt.rag.knowledge.markdown import MarkdownKnowledge - from dbgpt.rag.knowledge.pdf import PDFKnowledge - from dbgpt.rag.knowledge.pptx import PPTXKnowledge - from dbgpt.rag.knowledge.string import StringKnowledge - from dbgpt.rag.knowledge.txt import TXTKnowledge - from dbgpt.rag.knowledge.url import URLKnowledge + def _get_knowledge_subclasses() -> List["Type[Knowledge]"]: + """Get all knowledge subclasses.""" + from dbgpt.rag.knowledge.base import Knowledge # noqa: F401 + from dbgpt.rag.knowledge.csv import CSVKnowledge # noqa: F401 + from dbgpt.rag.knowledge.docx import DocxKnowledge # noqa: F401 + from dbgpt.rag.knowledge.html import HTMLKnowledge # noqa: F401 + from dbgpt.rag.knowledge.markdown import MarkdownKnowledge # noqa: F401 + from dbgpt.rag.knowledge.pdf import PDFKnowledge # noqa: F401 + from dbgpt.rag.knowledge.pptx import PPTXKnowledge # noqa: F401 + from dbgpt.rag.knowledge.string import StringKnowledge # noqa: F401 + from dbgpt.rag.knowledge.txt import TXTKnowledge # noqa: F401 + from dbgpt.rag.knowledge.url import URLKnowledge # noqa: F401 return Knowledge.__subclasses__() diff --git a/dbgpt/rag/knowledge/html.py b/dbgpt/rag/knowledge/html.py index 7fa3e545f..3af6fccd2 100644 --- a/dbgpt/rag/knowledge/html.py +++ b/dbgpt/rag/knowledge/html.py @@ -1,3 +1,4 @@ +"""HTML Knowledge.""" from typing import Any, List, Optional import chardet @@ -12,7 +13,7 @@ class HTMLKnowledge(Knowledge): - """HTML Knowledge""" + """HTML Knowledge.""" def __init__( self, @@ -21,21 +22,24 @@ def __init__( loader: Optional[Any] = None, **kwargs: Any, ) -> None: - """Initialize with Knowledge arguments. + """Create HTML Knowledge with Knowledge arguments. + Args: - file_path:(Optional[str]) file path - knowledge_type:(KnowledgeType) knowledge type - loader:(Optional[Any]) loader + file_path(str, optional): file path + knowledge_type(KnowledgeType, optional): knowledge type + loader(Any, optional): loader """ self._path = file_path self._type = knowledge_type self._loader = loader def _load(self) -> List[Document]: - """Load html document from loader""" + """Load html document from loader.""" if self._loader: documents = self._loader.load() else: + if not self._path: + raise ValueError("file path is required") with open(self._path, "rb") as f: raw_text = f.read() result = chardet.detect(raw_text) @@ -49,10 +53,9 @@ def _load(self) -> List[Document]: return [Document.langchain2doc(lc_document) for lc_document in documents] def _postprocess(self, documents: List[Document]): - i = 0 - for d in documents: - import markdown + import markdown + for i, d in enumerate(documents): content = markdown.markdown(d.content) from bs4 import BeautifulSoup @@ -61,11 +64,11 @@ def _postprocess(self, documents: List[Document]): tag.extract() documents[i].content = soup.get_text() documents[i].content = documents[i].content.replace("\n", " ") - i += 1 return documents @classmethod def support_chunk_strategy(cls): + """Return support chunk strategy.""" return [ ChunkStrategy.CHUNK_BY_SIZE, ChunkStrategy.CHUNK_BY_SEPARATOR, @@ -73,12 +76,15 @@ def support_chunk_strategy(cls): @classmethod def default_chunk_strategy(cls) -> ChunkStrategy: + """Return default chunk strategy.""" return ChunkStrategy.CHUNK_BY_SIZE @classmethod def type(cls) -> KnowledgeType: + """Return knowledge type.""" return KnowledgeType.DOCUMENT @classmethod def document_type(cls) -> DocumentType: + """Return document type.""" return DocumentType.HTML diff --git a/dbgpt/rag/knowledge/json.py b/dbgpt/rag/knowledge/json.py index e69de29bb..42d8f61e9 100644 --- a/dbgpt/rag/knowledge/json.py +++ b/dbgpt/rag/knowledge/json.py @@ -0,0 +1 @@ +"""Knowledge JSON.""" diff --git a/dbgpt/rag/knowledge/markdown.py b/dbgpt/rag/knowledge/markdown.py index 90270fd0f..a2e92706f 100644 --- a/dbgpt/rag/knowledge/markdown.py +++ b/dbgpt/rag/knowledge/markdown.py @@ -1,3 +1,4 @@ +"""Markdown Knowledge.""" from typing import Any, List, Optional from dbgpt.rag.chunk import Document @@ -10,7 +11,7 @@ class MarkdownKnowledge(Knowledge): - """Markdown Knowledge""" + """Markdown Knowledge.""" def __init__( self, @@ -20,12 +21,13 @@ def __init__( loader: Optional[Any] = None, **kwargs: Any, ) -> None: - """Initialize with Knowledge arguments. + """Create Markdown Knowledge with Knowledge arguments. + Args: - file_path:(Optional[str]) file path - knowledge_type:(KnowledgeType) knowledge type - encoding:(Optional[str]) encoding - loader:(Optional[Any]) loader + file_path(str, optional): file path + knowledge_type(KnowledgeType, optional): knowledge type + encoding(str, optional): csv encoding + loader(Any, optional): loader """ self._path = file_path self._type = knowledge_type @@ -33,10 +35,12 @@ def __init__( self._encoding = encoding def _load(self) -> List[Document]: - """Load markdown document from loader""" + """Load markdown document from loader.""" if self._loader: documents = self._loader.load() else: + if not self._path: + raise ValueError("file path is required") with open(self._path, encoding=self._encoding, errors="ignore") as f: markdown_text = f.read() metadata = {"source": self._path} @@ -46,6 +50,7 @@ def _load(self) -> List[Document]: @classmethod def support_chunk_strategy(cls) -> List[ChunkStrategy]: + """Return support chunk strategy.""" return [ ChunkStrategy.CHUNK_BY_SIZE, ChunkStrategy.CHUNK_BY_MARKDOWN_HEADER, @@ -54,12 +59,15 @@ def support_chunk_strategy(cls) -> List[ChunkStrategy]: @classmethod def default_chunk_strategy(cls) -> ChunkStrategy: + """Return default chunk strategy.""" return ChunkStrategy.CHUNK_BY_MARKDOWN_HEADER @classmethod def type(cls) -> KnowledgeType: + """Return knowledge type.""" return KnowledgeType.DOCUMENT @classmethod def document_type(cls) -> DocumentType: + """Return document type.""" return DocumentType.MARKDOWN diff --git a/dbgpt/rag/knowledge/pdf.py b/dbgpt/rag/knowledge/pdf.py index d6807f5d9..fab8e26e2 100644 --- a/dbgpt/rag/knowledge/pdf.py +++ b/dbgpt/rag/knowledge/pdf.py @@ -1,3 +1,4 @@ +"""PDF Knowledge.""" from typing import Any, List, Optional from dbgpt.rag.chunk import Document @@ -10,21 +11,23 @@ class PDFKnowledge(Knowledge): - """PDF Knowledge""" + """PDF Knowledge.""" def __init__( self, file_path: Optional[str] = None, knowledge_type: KnowledgeType = KnowledgeType.DOCUMENT, - loader: Optional = None, + loader: Optional[Any] = None, language: Optional[str] = "zh", **kwargs: Any, ) -> None: - """Initialize with PDF Knowledge arguments. + """Create PDF Knowledge with Knowledge arguments. + Args: - file_path:(Optional[str]) file path - knowledge_type:(KnowledgeType) knowledge type - loader:(Optional[Any]) loader + file_path(str, optional): file path + knowledge_type(KnowledgeType, optional): knowledge type + loader(Any, optional): loader + language(str, optional): language """ self._path = file_path self._type = knowledge_type @@ -32,7 +35,7 @@ def __init__( self._language = language def _load(self) -> List[Document]: - """Load pdf document from loader""" + """Load pdf document from loader.""" if self._loader: documents = self._loader.load() else: @@ -40,11 +43,13 @@ def _load(self) -> List[Document]: pages = [] documents = [] + if not self._path: + raise ValueError("file path is required") with open(self._path, "rb") as file: reader = pypdf.PdfReader(file) for page_num in range(len(reader.pages)): - page = reader.pages[page_num] - pages.append((page.extract_text(), page_num)) + _page = reader.pages[page_num] + pages.append((_page.extract_text(), page_num)) # cleaned_pages = [] for page, page_num in pages: @@ -53,10 +58,9 @@ def _load(self) -> List[Document]: cleaned_lines = [] for line in lines: if self._language == "en": - words = list(line) + words = list(line) # noqa: F841 else: - words = line.split() - digits = [word for word in words if any(i.isdigit() for i in word)] + words = line.split() # noqa: F841 cleaned_lines.append(line) page = "\n".join(cleaned_lines) # cleaned_pages.append(page) @@ -69,6 +73,7 @@ def _load(self) -> List[Document]: @classmethod def support_chunk_strategy(cls) -> List[ChunkStrategy]: + """Return support chunk strategy.""" return [ ChunkStrategy.CHUNK_BY_SIZE, ChunkStrategy.CHUNK_BY_PAGE, @@ -77,12 +82,15 @@ def support_chunk_strategy(cls) -> List[ChunkStrategy]: @classmethod def default_chunk_strategy(cls) -> ChunkStrategy: + """Return default chunk strategy.""" return ChunkStrategy.CHUNK_BY_SIZE @classmethod def type(cls) -> KnowledgeType: + """Return knowledge type.""" return KnowledgeType.DOCUMENT @classmethod def document_type(cls) -> DocumentType: + """Document type of PDF.""" return DocumentType.PDF diff --git a/dbgpt/rag/knowledge/pptx.py b/dbgpt/rag/knowledge/pptx.py index 90fc337bf..4f4a35e08 100644 --- a/dbgpt/rag/knowledge/pptx.py +++ b/dbgpt/rag/knowledge/pptx.py @@ -1,3 +1,4 @@ +"""PPTX Knowledge.""" from typing import Any, List, Optional from dbgpt.rag.chunk import Document @@ -10,17 +11,18 @@ class PPTXKnowledge(Knowledge): - """PPTX Knowledge""" + """PPTX Knowledge.""" def __init__( self, file_path: Optional[str] = None, knowledge_type: KnowledgeType = KnowledgeType.DOCUMENT, - loader: Optional = None, + loader: Optional[Any] = None, language: Optional[str] = "zh", **kwargs: Any, ) -> None: - """Initialize with PDF Knowledge arguments. + """Create PPTX knowledge with PDF Knowledge arguments. + Args: file_path:(Optional[str]) file path knowledge_type:(KnowledgeType) knowledge type @@ -32,7 +34,7 @@ def __init__( self._language = language def _load(self) -> List[Document]: - """Load pdf document from loader""" + """Load pdf document from loader.""" if self._loader: documents = self._loader.load() else: @@ -53,6 +55,11 @@ def _load(self) -> List[Document]: @classmethod def support_chunk_strategy(cls) -> List[ChunkStrategy]: + """Return support chunk strategy. + + Returns: + List[ChunkStrategy]: support chunk strategy + """ return [ ChunkStrategy.CHUNK_BY_SIZE, ChunkStrategy.CHUNK_BY_PAGE, @@ -61,12 +68,27 @@ def support_chunk_strategy(cls) -> List[ChunkStrategy]: @classmethod def default_chunk_strategy(cls) -> ChunkStrategy: + """Return default chunk strategy. + + Returns: + ChunkStrategy: default chunk strategy + """ return ChunkStrategy.CHUNK_BY_SIZE @classmethod def type(cls) -> KnowledgeType: + """Knowledge type of PPTX. + + Returns: + KnowledgeType: knowledge type + """ return KnowledgeType.DOCUMENT @classmethod def document_type(cls) -> DocumentType: + """Document type of PPTX. + + Returns: + DocumentType: document type + """ return DocumentType.PPTX diff --git a/dbgpt/rag/knowledge/string.py b/dbgpt/rag/knowledge/string.py index 44dc74541..fa007c6eb 100644 --- a/dbgpt/rag/knowledge/string.py +++ b/dbgpt/rag/knowledge/string.py @@ -1,3 +1,4 @@ +"""String Knowledge.""" from typing import Any, List, Optional from dbgpt.rag.chunk import Document @@ -5,22 +6,23 @@ class StringKnowledge(Knowledge): - """String Knowledge""" + """String Knowledge.""" def __init__( self, - text: str = None, + text: str = "", knowledge_type: KnowledgeType = KnowledgeType.TEXT, encoding: Optional[str] = "utf-8", loader: Optional[Any] = None, **kwargs: Any, ) -> None: - """Initialize with Knowledge arguments. + """Create String knowledge parameters. + Args: - text:(str) text - knowledge_type:(KnowledgeType) knowledge type - encoding:(encoding) csv encoding - loader:(loader) loader + text(str): text + knowledge_type(KnowledgeType): knowledge type + encoding(str): encoding + loader(Any): loader """ self._text = text self._type = knowledge_type @@ -28,21 +30,25 @@ def __init__( self._encoding = encoding def _load(self) -> List[Document]: - """load raw text from loader""" + """Load raw text from loader.""" metadata = {"source": "raw text"} docs = [Document(content=self._text, metadata=metadata)] return docs @classmethod def support_chunk_strategy(cls) -> List[ChunkStrategy]: + """Return support chunk strategy.""" return [ ChunkStrategy.CHUNK_BY_SIZE, ChunkStrategy.CHUNK_BY_SEPARATOR, ] + @classmethod def default_chunk_strategy(cls) -> ChunkStrategy: + """Return default chunk strategy.""" return ChunkStrategy.CHUNK_BY_SIZE @classmethod def type(cls): + """Return knowledge type.""" return KnowledgeType.TEXT diff --git a/dbgpt/rag/knowledge/txt.py b/dbgpt/rag/knowledge/txt.py index 7be946133..6c33b136b 100644 --- a/dbgpt/rag/knowledge/txt.py +++ b/dbgpt/rag/knowledge/txt.py @@ -1,3 +1,4 @@ +"""TXT Knowledge.""" from typing import Any, List, Optional import chardet @@ -12,7 +13,7 @@ class TXTKnowledge(Knowledge): - """TXT Knowledge""" + """TXT Knowledge.""" def __init__( self, @@ -21,21 +22,24 @@ def __init__( loader: Optional[Any] = None, **kwargs: Any, ) -> None: - """Initialize with Knowledge arguments. + """Create TXT Knowledge with Knowledge arguments. + Args: - file_path:(Optional[str]) file path - knowledge_type:(KnowledgeType) knowledge type - loader:(Optional[Any]) loader + file_path(str, optional): file path + knowledge_type(KnowledgeType, optional): knowledge type + loader(Any, optional): loader """ self._path = file_path self._type = knowledge_type self._loader = loader def _load(self) -> List[Document]: - """Load txt document from loader""" + """Load txt document from loader.""" if self._loader: documents = self._loader.load() else: + if not self._path: + raise ValueError("file path is required") with open(self._path, "rb") as f: raw_text = f.read() result = chardet.detect(raw_text) @@ -50,6 +54,7 @@ def _load(self) -> List[Document]: @classmethod def support_chunk_strategy(cls): + """Return support chunk strategy.""" return [ ChunkStrategy.CHUNK_BY_SIZE, ChunkStrategy.CHUNK_BY_SEPARATOR, @@ -57,12 +62,15 @@ def support_chunk_strategy(cls): @classmethod def default_chunk_strategy(cls) -> ChunkStrategy: + """Return default chunk strategy.""" return ChunkStrategy.CHUNK_BY_SIZE @classmethod def type(cls) -> KnowledgeType: + """Return knowledge type.""" return KnowledgeType.DOCUMENT @classmethod def document_type(cls) -> DocumentType: + """Return document type.""" return DocumentType.TXT diff --git a/dbgpt/rag/knowledge/url.py b/dbgpt/rag/knowledge/url.py index f1f4ab78c..aaeb74956 100644 --- a/dbgpt/rag/knowledge/url.py +++ b/dbgpt/rag/knowledge/url.py @@ -1,3 +1,4 @@ +"""URL Knowledge.""" from typing import Any, List, Optional from dbgpt.rag.chunk import Document @@ -5,22 +6,25 @@ class URLKnowledge(Knowledge): + """URL Knowledge.""" + def __init__( self, - url: Optional[str] = None, + url: str = "", knowledge_type: KnowledgeType = KnowledgeType.URL, source_column: Optional[str] = None, encoding: Optional[str] = "utf-8", loader: Optional[Any] = None, **kwargs: Any, ) -> None: - """Initialize with Knowledge arguments. + """Create URL Knowledge with Knowledge arguments. + Args: - url:(Optional[str]) url - knowledge_type:(KnowledgeType) knowledge type - source_column:(Optional[str]) source column - encoding:(Optional[str]) csv encoding - loader:(Optional[Any]) loader + url(str, optional): url + knowledge_type(KnowledgeType, optional): knowledge type + source_column(str, optional): source column + encoding(str, optional): csv encoding + loader(Any, optional): loader """ self._path = url self._type = knowledge_type @@ -29,7 +33,7 @@ def __init__( self._source_column = source_column def _load(self) -> List[Document]: - """Fetch URL document from loader""" + """Fetch URL document from loader.""" if self._loader: documents = self._loader.load() else: @@ -41,6 +45,7 @@ def _load(self) -> List[Document]: @classmethod def support_chunk_strategy(cls) -> List[ChunkStrategy]: + """Return support chunk strategy.""" return [ ChunkStrategy.CHUNK_BY_SIZE, ChunkStrategy.CHUNK_BY_SEPARATOR, @@ -48,8 +53,10 @@ def support_chunk_strategy(cls) -> List[ChunkStrategy]: @classmethod def default_chunk_strategy(cls) -> ChunkStrategy: + """Return default chunk strategy.""" return ChunkStrategy.CHUNK_BY_SIZE @classmethod def type(cls): + """Return knowledge type.""" return KnowledgeType.URL diff --git a/dbgpt/rag/operators/__init__.py b/dbgpt/rag/operators/__init__.py index e69de29bb..41bafcbbb 100644 --- a/dbgpt/rag/operators/__init__.py +++ b/dbgpt/rag/operators/__init__.py @@ -0,0 +1,19 @@ +"""Module for RAG operators.""" + +from .datasource import DatasourceRetrieverOperator # noqa: F401 +from .db_schema import DBSchemaRetrieverOperator # noqa: F401 +from .embedding import EmbeddingRetrieverOperator # noqa: F401 +from .knowledge import KnowledgeOperator # noqa: F401 +from .rerank import RerankOperator # noqa: F401 +from .rewrite import QueryRewriteOperator # noqa: F401 +from .summary import SummaryAssemblerOperator # noqa: F401 + +__all__ = [ + "DatasourceRetrieverOperator", + "DBSchemaRetrieverOperator", + "EmbeddingRetrieverOperator", + "KnowledgeOperator", + "RerankOperator", + "QueryRewriteOperator", + "SummaryAssemblerOperator", +] diff --git a/dbgpt/rag/operators/datasource.py b/dbgpt/rag/operators/datasource.py index e86723f12..b236c2cff 100644 --- a/dbgpt/rag/operators/datasource.py +++ b/dbgpt/rag/operators/datasource.py @@ -1,3 +1,5 @@ +"""Datasource operator for RDBMS database.""" + from typing import Any from dbgpt.core.interface.operators.retriever import RetrieverOperator @@ -6,10 +8,14 @@ class DatasourceRetrieverOperator(RetrieverOperator[Any, Any]): + """The Datasource Retriever Operator.""" + def __init__(self, connection: RDBMSDatabase, **kwargs): + """Create a new DatasourceRetrieverOperator.""" super().__init__(**kwargs) self._connection = connection def retrieve(self, input_value: Any) -> Any: + """Retrieve the database summary.""" summary = _parse_db_summary(self._connection) return summary diff --git a/dbgpt/rag/operators/db_schema.py b/dbgpt/rag/operators/db_schema.py index 192fd489a..49fffa3e4 100644 --- a/dbgpt/rag/operators/db_schema.py +++ b/dbgpt/rag/operators/db_schema.py @@ -1,6 +1,7 @@ +"""The DBSchema Retriever Operator.""" + from typing import Any, Optional -from dbgpt.core.awel.task.base import IN from dbgpt.core.interface.operators.retriever import RetrieverOperator from dbgpt.datasource.rdbms.base import RDBMSDatabase from dbgpt.rag.retriever.db_schema import DBSchemaRetriever @@ -9,19 +10,22 @@ class DBSchemaRetrieverOperator(RetrieverOperator[Any, Any]): """The DBSchema Retriever Operator. + Args: connection (RDBMSDatabase): The connection. top_k (int, optional): The top k. Defaults to 4. - vector_store_connector (VectorStoreConnector, optional): The vector store connector. Defaults to None. + vector_store_connector (VectorStoreConnector, optional): The vector store + connector. Defaults to None. """ def __init__( self, + vector_store_connector: VectorStoreConnector, top_k: int = 4, connection: Optional[RDBMSDatabase] = None, - vector_store_connector: Optional[VectorStoreConnector] = None, **kwargs ): + """Create a new DBSchemaRetrieverOperator.""" super().__init__(**kwargs) self._retriever = DBSchemaRetriever( top_k=top_k, @@ -29,8 +33,9 @@ def __init__( vector_store_connector=vector_store_connector, ) - def retrieve(self, query: IN) -> Any: - """retrieve table schemas. + def retrieve(self, query: Any) -> Any: + """Retrieve the table schemas. + Args: query (IN): query. """ diff --git a/dbgpt/rag/operators/embedding.py b/dbgpt/rag/operators/embedding.py index f4810b6ce..a5e2096ab 100644 --- a/dbgpt/rag/operators/embedding.py +++ b/dbgpt/rag/operators/embedding.py @@ -1,7 +1,8 @@ +"""Embedding retriever operator.""" + from functools import reduce from typing import Any, Optional -from dbgpt.core.awel.task.base import IN from dbgpt.core.interface.operators.retriever import RetrieverOperator from dbgpt.rag.retriever.embedding import EmbeddingRetriever from dbgpt.rag.retriever.rerank import Ranker @@ -10,25 +11,29 @@ class EmbeddingRetrieverOperator(RetrieverOperator[Any, Any]): + """The Embedding Retriever Operator.""" + def __init__( self, + vector_store_connector: VectorStoreConnector, top_k: int, - score_threshold: Optional[float] = 0.3, + score_threshold: float = 0.3, query_rewrite: Optional[QueryRewrite] = None, - rerank: Ranker = None, - vector_store_connector: VectorStoreConnector = None, + rerank: Optional[Ranker] = None, **kwargs ): + """Create a new EmbeddingRetrieverOperator.""" super().__init__(**kwargs) self._score_threshold = score_threshold self._retriever = EmbeddingRetriever( + vector_store_connector=vector_store_connector, top_k=top_k, query_rewrite=query_rewrite, rerank=rerank, - vector_store_connector=vector_store_connector, ) - def retrieve(self, query: IN) -> Any: + def retrieve(self, query: Any) -> Any: + """Retrieve the candidates.""" if isinstance(query, str): return self._retriever.retrieve_with_scores(query, self._score_threshold) elif isinstance(query, list): diff --git a/dbgpt/rag/operators/knowledge.py b/dbgpt/rag/operators/knowledge.py index e7e74a19c..cfa572260 100644 --- a/dbgpt/rag/operators/knowledge.py +++ b/dbgpt/rag/operators/knowledge.py @@ -1,4 +1,6 @@ -from typing import Any, List, Optional +"""Knowledge Operator.""" + +from typing import Any, Optional from dbgpt.core.awel import MapOperator from dbgpt.core.awel.flow import ( @@ -8,7 +10,6 @@ Parameter, ViewMetadata, ) -from dbgpt.core.awel.task.base import IN from dbgpt.rag.knowledge.base import Knowledge, KnowledgeType from dbgpt.rag.knowledge.factory import KnowledgeFactory @@ -76,6 +77,7 @@ def __init__( **kwargs ): """Init the query rewrite operator. + Args: knowledge_type: (Optional[KnowledgeType]) The knowledge type. """ @@ -83,8 +85,8 @@ def __init__( self._datasource = datasource self._knowledge_type = KnowledgeType.get_by_value(knowledge_type) - async def map(self, datasource: IN) -> Knowledge: - """knowledge operator.""" + async def map(self, datasource: Any) -> Knowledge: + """Create knowledge from datasource.""" if self._datasource: datasource = self._datasource return await self.blocking_func_to_async( diff --git a/dbgpt/rag/operators/rerank.py b/dbgpt/rag/operators/rerank.py index bb6485b6e..cde2c8b08 100644 --- a/dbgpt/rag/operators/rerank.py +++ b/dbgpt/rag/operators/rerank.py @@ -1,11 +1,9 @@ +"""The Rerank Operator.""" from typing import Any, List, Optional -from dbgpt.core import LLMClient from dbgpt.core.awel import MapOperator -from dbgpt.core.awel.task.base import IN from dbgpt.rag.chunk import Chunk -from dbgpt.rag.retriever.rerank import DefaultRanker -from dbgpt.rag.retriever.rewrite import QueryRewrite +from dbgpt.rag.retriever.rerank import RANK_FUNC, DefaultRanker class RerankOperator(MapOperator[Any, Any]): @@ -13,12 +11,13 @@ class RerankOperator(MapOperator[Any, Any]): def __init__( self, - topk: Optional[int] = 3, - algorithm: Optional[str] = "default", - rank_fn: Optional[callable] = None, + topk: int = 3, + algorithm: str = "default", + rank_fn: Optional[RANK_FUNC] = None, **kwargs ): - """Init the query rewrite operator. + """Create a new RerankOperator. + Args: topk (int): The number of the candidates. algorithm (Optional[str]): The rerank algorithm name. @@ -31,10 +30,11 @@ def __init__( rank_fn=rank_fn, ) - async def map(self, candidates_with_scores: IN) -> List[Chunk]: - """rerank the candidates. + async def map(self, candidates_with_scores: List[Chunk]) -> List[Chunk]: + """Rerank the candidates. + Args: - candidates_with_scores (IN): The candidates with scores. + candidates_with_scores (List[Chunk]): The candidates with scores. Returns: List[Chunk]: The reranked candidates. """ diff --git a/dbgpt/rag/operators/rewrite.py b/dbgpt/rag/operators/rewrite.py index d911c0b0a..1f3e1c9b3 100644 --- a/dbgpt/rag/operators/rewrite.py +++ b/dbgpt/rag/operators/rewrite.py @@ -1,13 +1,14 @@ +"""The rewrite operator.""" + from typing import Any, List, Optional from dbgpt.core import LLMClient from dbgpt.core.awel import MapOperator from dbgpt.core.awel.flow import IOField, OperatorCategory, Parameter, ViewMetadata -from dbgpt.core.awel.task.base import IN from dbgpt.rag.retriever.rewrite import QueryRewrite -class QueryRewriteOperator(MapOperator[Any, Any]): +class QueryRewriteOperator(MapOperator[dict, Any]): """The Rewrite Operator.""" metadata = ViewMetadata( @@ -22,7 +23,8 @@ class QueryRewriteOperator(MapOperator[Any, Any]): IOField.build_from( "rewritten queries", "queries", - List[str], + str, + is_list=True, description="rewritten queries", ) ], @@ -31,8 +33,6 @@ class QueryRewriteOperator(MapOperator[Any, Any]): "LLM Client", "llm_client", LLMClient, - optional=True, - default=None, description="The LLM Client.", ), Parameter.build_from( @@ -65,13 +65,14 @@ class QueryRewriteOperator(MapOperator[Any, Any]): def __init__( self, - llm_client: Optional[LLMClient], - model_name: Optional[str] = None, + llm_client: LLMClient, + model_name: str = "gpt-3.5-turbo", language: Optional[str] = "en", nums: Optional[int] = 1, **kwargs ): """Init the query rewrite operator. + Args: llm_client (Optional[LLMClient]): The LLM client. model_name (Optional[str]): The model name. @@ -86,10 +87,12 @@ def __init__( language=language, ) - async def map(self, query_context: IN) -> List[str]: + async def map(self, query_context: dict) -> List[str]: """Rewrite the query.""" query = query_context.get("query") context = query_context.get("context") + if not query: + raise ValueError("query is required") return await self._rewrite.rewrite( origin_query=query, context=context, nums=self._nums ) diff --git a/dbgpt/rag/operators/schema_linking.py b/dbgpt/rag/operators/schema_linking.py index a10609f4f..59f9c7185 100644 --- a/dbgpt/rag/operators/schema_linking.py +++ b/dbgpt/rag/operators/schema_linking.py @@ -1,3 +1,8 @@ +"""Simple schema linking operator. + +Warning: This operator is in development and is not yet ready for production use. +""" + from typing import Any, Optional from dbgpt.core import LLMClient @@ -12,14 +17,15 @@ class SchemaLinkingOperator(MapOperator[Any, Any]): def __init__( self, + connection: RDBMSDatabase, + model_name: str, + llm: LLMClient, top_k: int = 5, - connection: Optional[RDBMSDatabase] = None, - llm: Optional[LLMClient] = None, - model_name: Optional[str] = None, vector_store_connector: Optional[VectorStoreConnector] = None, **kwargs ): - """Init the schema linking operator + """Create the schema linking operator. + Args: connection (RDBMSDatabase): The connection. llm (Optional[LLMClient]): base llm @@ -35,10 +41,12 @@ def __init__( ) async def map(self, query: str) -> str: - """retrieve table schemas. + """Retrieve the table schemas with llm. + Args: query (str): query. + Return: - str: schema info + str: schema information. """ return str(await self._schema_linking.schema_linking_with_llm(query)) diff --git a/dbgpt/rag/operators/summary.py b/dbgpt/rag/operators/summary.py index 4f9ce0ae6..2eb83bb6b 100644 --- a/dbgpt/rag/operators/summary.py +++ b/dbgpt/rag/operators/summary.py @@ -1,14 +1,17 @@ +"""The summary operator.""" + from typing import Any, Optional from dbgpt.core import LLMClient from dbgpt.core.awel.flow import IOField, OperatorCategory, Parameter, ViewMetadata -from dbgpt.core.awel.task.base import IN from dbgpt.rag.knowledge.base import Knowledge from dbgpt.serve.rag.assembler.summary import SummaryAssembler from dbgpt.serve.rag.operators.base import AssemblerOperator class SummaryAssemblerOperator(AssemblerOperator[Any, Any]): + """The summary assembler operator.""" + metadata = ViewMetadata( label="Summary Operator", name="summary_assembler_operator", @@ -81,14 +84,15 @@ def __init__( concurrency_limit_with_llm: Optional[int] = 3, **kwargs ): - """ - Init the summary assemble operator. + """Create the summary assemble operator. + Args: llm_client: (Optional[LLMClient]) The LLM client. model_name: (Optional[str]) The model name. language: (Optional[str]) The prompt language. max_iteration_with_llm: (Optional[int]) The max iteration with llm. - concurrency_limit_with_llm: (Optional[int]) The concurrency limit with llm. + concurrency_limit_with_llm: (Optional[int]) The concurrency limit with + llm. """ super().__init__(**kwargs) self._llm_client = llm_client @@ -97,7 +101,7 @@ def __init__( self._max_iteration_with_llm = max_iteration_with_llm self._concurrency_limit_with_llm = concurrency_limit_with_llm - async def map(self, knowledge: IN) -> Any: + async def map(self, knowledge: Knowledge) -> str: """Assemble the summary.""" assembler = SummaryAssembler.load_from_knowledge( knowledge=knowledge, @@ -109,6 +113,6 @@ async def map(self, knowledge: IN) -> Any: ) return await assembler.generate_summary() - def assemble(self, knowledge: IN) -> Any: - """assemble knowledge for input value.""" + def assemble(self, knowledge: Knowledge) -> Any: + """Assemble the summary.""" pass diff --git a/dbgpt/rag/retriever/__init__.py b/dbgpt/rag/retriever/__init__.py index e69de29bb..874302a0b 100644 --- a/dbgpt/rag/retriever/__init__.py +++ b/dbgpt/rag/retriever/__init__.py @@ -0,0 +1,18 @@ +"""Module Of Retriever.""" + +from .base import BaseRetriever, RetrieverStrategy # noqa: F401 +from .db_schema import DBSchemaRetriever # noqa: F401 +from .embedding import EmbeddingRetriever # noqa: F401 +from .rerank import DefaultRanker, Ranker, RRFRanker # noqa: F401 +from .rewrite import QueryRewrite # noqa: F401 + +__all__ = [ + "RetrieverStrategy", + "BaseRetriever", + "DBSchemaRetriever", + "EmbeddingRetriever", + "Ranker", + "DefaultRanker", + "RRFRanker", + "QueryRewrite", +] diff --git a/dbgpt/rag/retriever/base.py b/dbgpt/rag/retriever/base.py index 86c8133d1..0ecf28e85 100644 --- a/dbgpt/rag/retriever/base.py +++ b/dbgpt/rag/retriever/base.py @@ -1,12 +1,14 @@ +"""Base retriever module.""" from abc import ABC, abstractmethod from enum import Enum -from typing import List, Tuple +from typing import List from dbgpt.rag.chunk import Chunk class RetrieverStrategy(str, Enum): """Retriever strategy. + Args: - EMBEDDING: embedding retriever - KEYWORD: keyword retriever @@ -22,38 +24,48 @@ class BaseRetriever(ABC): """Base retriever.""" def retrieve(self, query: str) -> List[Chunk]: - """ + """Retrieve knowledge chunks. + Args: query (str): query text + Returns: List[Chunk]: list of chunks """ return self._retrieve(query) async def aretrieve(self, query: str) -> List[Chunk]: - """ + """Retrieve knowledge chunks. + Args: query (str): async query text + Returns: List[Chunk]: list of chunks """ return await self._aretrieve(query) def retrieve_with_scores(self, query: str, score_threshold: float) -> List[Chunk]: - """ + """Retrieve knowledge chunks with score. + Args: query (str): query text score_threshold (float): score threshold + + Returns: + List[Chunk]: list of chunks """ return self._retrieve_with_score(query, score_threshold) async def aretrieve_with_scores( self, query: str, score_threshold: float ) -> List[Chunk]: - """ + """Retrieve knowledge chunks with score. + Args: query (str): query text score_threshold (float): score threshold + Returns: List[Chunk]: list of chunks """ @@ -62,8 +74,10 @@ async def aretrieve_with_scores( @abstractmethod def _retrieve(self, query: str) -> List[Chunk]: """Retrieve knowledge chunks. + Args: query (str): query text + Returns: List[Chunk]: list of chunks """ @@ -71,8 +85,10 @@ def _retrieve(self, query: str) -> List[Chunk]: @abstractmethod async def _aretrieve(self, query: str) -> List[Chunk]: """Async Retrieve knowledge chunks. + Args: query (str): query text + Returns: List[Chunk]: list of chunks """ @@ -80,9 +96,11 @@ async def _aretrieve(self, query: str) -> List[Chunk]: @abstractmethod def _retrieve_with_score(self, query: str, score_threshold: float) -> List[Chunk]: """Retrieve knowledge chunks with score. + Args: query (str): query text score_threshold (float): score threshold + Returns: List[Chunk]: list of chunks """ @@ -92,9 +110,11 @@ async def _aretrieve_with_score( self, query: str, score_threshold: float ) -> List[Chunk]: """Async Retrieve knowledge chunks with score. + Args: query (str): query text score_threshold (float): score threshold + Returns: List[Chunk]: list of chunks """ diff --git a/dbgpt/rag/retriever/db_schema.py b/dbgpt/rag/retriever/db_schema.py index 72fe425f0..bfd74dd1f 100644 --- a/dbgpt/rag/retriever/db_schema.py +++ b/dbgpt/rag/retriever/db_schema.py @@ -1,5 +1,6 @@ +"""DBSchema retriever.""" from functools import reduce -from typing import List, Optional +from typing import List, Optional, cast from dbgpt.datasource.rdbms.base import RDBMSDatabase from dbgpt.rag.chunk import Chunk @@ -15,23 +16,23 @@ class DBSchemaRetriever(BaseRetriever): def __init__( self, + vector_store_connector: VectorStoreConnector, top_k: int = 4, connection: Optional[RDBMSDatabase] = None, query_rewrite: bool = False, - rerank: Ranker = None, - vector_store_connector: Optional[VectorStoreConnector] = None, + rerank: Optional[Ranker] = None, **kwargs ): - """ + """Create DBSchemaRetriever. + Args: + vector_store_connector (VectorStoreConnector): vector store connector top_k (int): top k connection (Optional[RDBMSDatabase]): RDBMSDatabase connection. query_rewrite (bool): query rewrite rerank (Ranker): rerank - vector_store_connector (VectorStoreConnector): vector store connector Examples: - .. code-block:: python from dbgpt.datasource.rdbms.conn_sqlite import SQLiteTempConnect @@ -78,12 +79,9 @@ def _create_temporary_connection(): top_k=3, vector_store_connector=vector_connector ) chunks = retriever.retrieve("show columns from table") - print( - f"db struct rag example results:{[chunk.content for chunk in chunks]}" - ) - + result = [chunk.content for chunk in chunks] + print(f"db struct rag example results:{result}") """ - self._top_k = top_k self._connection = connection self._query_rewrite = query_rewrite @@ -95,8 +93,12 @@ def _create_temporary_connection(): def _retrieve(self, query: str) -> List[Chunk]: """Retrieve knowledge chunks. + Args: query (str): query text + + Returns: + List[Chunk]: list of chunks """ if self._need_embeddings: queries = [query] @@ -104,32 +106,45 @@ def _retrieve(self, query: str) -> List[Chunk]: self._vector_store_connector.similar_search(query, self._top_k) for query in queries ] - candidates = reduce(lambda x, y: x + y, candidates) - return candidates + return cast(List[Chunk], reduce(lambda x, y: x + y, candidates)) else: + if not self._connection: + raise RuntimeError("RDBMSDatabase connection is required.") table_summaries = _parse_db_summary(self._connection) return [Chunk(content=table_summary) for table_summary in table_summaries] def _retrieve_with_score(self, query: str, score_threshold: float) -> List[Chunk]: """Retrieve knowledge chunks with score. + Args: query (str): query text score_threshold (float): score threshold + + Returns: + List[Chunk]: list of chunks """ return self._retrieve(query) async def _aretrieve(self, query: str) -> List[Chunk]: """Retrieve knowledge chunks. + Args: query (str): query text + + Returns: + List[Chunk]: list of chunks """ if self._need_embeddings: queries = [query] candidates = [self._similarity_search(query) for query in queries] - candidates = await run_async_tasks(tasks=candidates, concurrency_limit=1) - return candidates + result_candidates = await run_async_tasks( + tasks=candidates, concurrency_limit=1 + ) + return result_candidates else: - from dbgpt.rag.summary.rdbms_db_summary import _parse_db_summary + from dbgpt.rag.summary.rdbms_db_summary import ( # noqa: F401 + _parse_db_summary, + ) table_summaries = await run_async_tasks( tasks=[self._aparse_db_summary()], concurrency_limit=1 @@ -140,6 +155,7 @@ async def _aretrieve_with_score( self, query: str, score_threshold: float ) -> List[Chunk]: """Retrieve knowledge chunks with score. + Args: query (str): query text score_threshold (float): score threshold @@ -157,4 +173,6 @@ async def _aparse_db_summary(self) -> List[str]: """Similar search.""" from dbgpt.rag.summary.rdbms_db_summary import _parse_db_summary - return _parse_db_summary() + if not self._connection: + raise RuntimeError("RDBMSDatabase connection is required.") + return _parse_db_summary(self._connection) diff --git a/dbgpt/rag/retriever/embedding.py b/dbgpt/rag/retriever/embedding.py index fd4c687b4..e55c38465 100644 --- a/dbgpt/rag/retriever/embedding.py +++ b/dbgpt/rag/retriever/embedding.py @@ -1,5 +1,6 @@ +"""Embedding retriever.""" from functools import reduce -from typing import List, Optional +from typing import List, Optional, cast from dbgpt.rag.chunk import Chunk from dbgpt.rag.retriever.base import BaseRetriever @@ -15,12 +16,13 @@ class EmbeddingRetriever(BaseRetriever): def __init__( self, + vector_store_connector: VectorStoreConnector, top_k: int = 4, query_rewrite: Optional[QueryRewrite] = None, - rerank: Ranker = None, - vector_store_connector: VectorStoreConnector = None, + rerank: Optional[Ranker] = None, ): - """ + """Create EmbeddingRetriever. + Args: top_k (int): top k query_rewrite (Optional[QueryRewrite]): query rewrite @@ -28,7 +30,6 @@ def __init__( vector_store_connector (VectorStoreConnector): vector store connector Examples: - .. code-block:: python from dbgpt.storage.vector_store.connector import VectorStoreConnector @@ -66,8 +67,10 @@ def __init__( def _retrieve(self, query: str) -> List[Chunk]: """Retrieve knowledge chunks. + Args: query (str): query text + Return: List[Chunk]: list of chunks """ @@ -76,14 +79,16 @@ def _retrieve(self, query: str) -> List[Chunk]: self._vector_store_connector.similar_search(query, self._top_k) for query in queries ] - candidates = reduce(lambda x, y: x + y, candidates) - return candidates + res_candidates = cast(List[Chunk], reduce(lambda x, y: x + y, candidates)) + return res_candidates def _retrieve_with_score(self, query: str, score_threshold: float) -> List[Chunk]: """Retrieve knowledge chunks with score. + Args: query (str): query text score_threshold (float): score threshold + Return: List[Chunk]: list of chunks with score """ @@ -94,14 +99,18 @@ def _retrieve_with_score(self, query: str, score_threshold: float) -> List[Chunk ) for query in queries ] - candidates_with_score = reduce(lambda x, y: x + y, candidates_with_score) - candidates_with_score = self._rerank.rank(candidates_with_score) - return candidates_with_score + new_candidates_with_score = cast( + List[Chunk], reduce(lambda x, y: x + y, candidates_with_score) + ) + new_candidates_with_score = self._rerank.rank(new_candidates_with_score) + return new_candidates_with_score async def _aretrieve(self, query: str) -> List[Chunk]: """Retrieve knowledge chunks. + Args: query (str): query text + Return: List[Chunk]: list of chunks """ @@ -115,16 +124,18 @@ async def _aretrieve(self, query: str) -> List[Chunk]: ) queries.extend(new_queries) candidates = [self._similarity_search(query) for query in queries] - candidates = await run_async_tasks(tasks=candidates, concurrency_limit=1) - return candidates + new_candidates = await run_async_tasks(tasks=candidates, concurrency_limit=1) + return new_candidates async def _aretrieve_with_score( self, query: str, score_threshold: float ) -> List[Chunk]: """Retrieve knowledge chunks with score. + Args: query (str): query text score_threshold (float): score threshold + Return: List[Chunk]: list of chunks with score """ @@ -154,10 +165,12 @@ async def _aretrieve_with_score( self._similarity_search_with_score(query, score_threshold) for query in queries ] - candidates_with_score = await run_async_tasks( + res_candidates_with_score = await run_async_tasks( tasks=candidates_with_score, concurrency_limit=1 ) - candidates_with_score = reduce(lambda x, y: x + y, candidates_with_score) + new_candidates_with_score = cast( + List[Chunk], reduce(lambda x, y: x + y, res_candidates_with_score) + ) with root_tracer.start_span( "EmbeddingRetriever.rerank", @@ -167,8 +180,8 @@ async def _aretrieve_with_score( "rerank_cls": self._rerank.__class__.__name__, }, ): - candidates_with_score = self._rerank.rank(candidates_with_score) - return candidates_with_score + new_candidates_with_score = self._rerank.rank(new_candidates_with_score) + return new_candidates_with_score async def _similarity_search(self, query) -> List[Chunk]: """Similar search.""" @@ -181,7 +194,7 @@ async def _run_async_tasks(self, tasks) -> List[Chunk]: """Run async tasks.""" candidates = await run_async_tasks(tasks=tasks, concurrency_limit=1) candidates = reduce(lambda x, y: x + y, candidates) - return candidates + return cast(List[Chunk], candidates) async def _similarity_search_with_score( self, query, score_threshold diff --git a/dbgpt/rag/retriever/rerank.py b/dbgpt/rag/retriever/rerank.py index a2a97f4dd..cd29db551 100644 --- a/dbgpt/rag/retriever/rerank.py +++ b/dbgpt/rag/retriever/rerank.py @@ -1,15 +1,19 @@ -from abc import ABC -from typing import List, Optional +"""Rerank module for RAG retriever.""" + +from abc import ABC, abstractmethod +from typing import Callable, List, Optional from dbgpt.rag.chunk import Chunk +RANK_FUNC = Callable[[List[Chunk]], List[Chunk]] + class Ranker(ABC): - """Base Ranker""" + """Base Ranker.""" + + def __init__(self, topk: int, rank_fn: Optional[RANK_FUNC] = None) -> None: + """Create abstract base ranker. - def __init__(self, topk: int, rank_fn: Optional[callable] = None) -> None: - """ - abstract base ranker Args: topk: int rank_fn: Optional[callable] @@ -17,19 +21,23 @@ def __init__(self, topk: int, rank_fn: Optional[callable] = None) -> None: self.topk = topk self.rank_fn = rank_fn + @abstractmethod def rank(self, candidates_with_scores: List) -> List[Chunk]: - """rank algorithm implementation return topk documents by candidates similarity score + """Return top k chunks after ranker. + + Rank algorithm implementation return topk documents by candidates + similarity score + Args: candidates_with_scores: List[Tuple] topk: int + Return: List[Document] """ - pass - def _filter(self, candidates_with_scores: List) -> List[Chunk]: - """filter duplicate candidates documents""" + """Filter duplicate candidates documents.""" candidates_with_scores = sorted( candidates_with_scores, key=lambda x: x.score, reverse=True ) @@ -43,18 +51,22 @@ def _filter(self, candidates_with_scores: List) -> List[Chunk]: class DefaultRanker(Ranker): - """Default Ranker""" + """Default Ranker.""" - def __init__(self, topk: int, rank_fn: Optional[callable] = None): + def __init__(self, topk: int, rank_fn: Optional[RANK_FUNC] = None): + """Create Default Ranker with topk and rank_fn.""" super().__init__(topk, rank_fn) def rank(self, candidates_with_scores: List[Chunk]) -> List[Chunk]: - """Default rank algorithm implementation - return topk documents by candidates similarity score + """Return top k chunks after ranker. + + Return top k documents by candidates similarity score + Args: candidates_with_scores: List[Tuple] + Return: - List[Document] + List[Chunk]: List of top k documents """ candidates_with_scores = self._filter(candidates_with_scores) if self.rank_fn is not None: @@ -67,14 +79,21 @@ def rank(self, candidates_with_scores: List[Chunk]) -> List[Chunk]: class RRFRanker(Ranker): - """RRF(Reciprocal Rank Fusion) Ranker""" + """RRF(Reciprocal Rank Fusion) Ranker.""" - def __init__(self, topk: int, rank_fn: Optional[callable] = None): + def __init__(self, topk: int, rank_fn: Optional[RANK_FUNC] = None): + """RRF rank algorithm implementation.""" super().__init__(topk, rank_fn) def rank(self, candidates_with_scores: List[Chunk]) -> List[Chunk]: - """RRF rank algorithm implementation - This code implements an algorithm called Reciprocal Rank Fusion (RRF), is a method for combining multiple result sets with different relevance indicators into a single result set. RRF requires no tuning, and the different relevance indicators do not have to be related to each other to achieve high-quality results. + """RRF rank algorithm implementation. + + This code implements an algorithm called Reciprocal Rank Fusion (RRF), is a + method for combining multiple result sets with different relevance indicators + into a single result set. RRF requires no tuning, and the different relevance + indicators do not have to be related to each other to achieve high-quality + results. + RRF uses the following formula to determine the score for ranking each document: score = 0.0 for q in queries: diff --git a/dbgpt/rag/retriever/rewrite.py b/dbgpt/rag/retriever/rewrite.py index b85d01a87..d80573b49 100644 --- a/dbgpt/rag/retriever/rewrite.py +++ b/dbgpt/rag/retriever/rewrite.py @@ -1,36 +1,41 @@ +"""Query rewrite.""" from typing import List, Optional from dbgpt.core import LLMClient, ModelMessage, ModelMessageRoleType, ModelRequest REWRITE_PROMPT_TEMPLATE_EN = """ -Based on the given context {context}, Generate {nums} search queries related to: {original_query}, Provide following comma-separated format: 'queries: '": +Based on the given context {context}, Generate {nums} search queries related to: +{original_query}, Provide following comma-separated format: 'queries: '": "original query:{original_query}\n" "queries:" """ -REWRITE_PROMPT_TEMPLATE_ZH = """请根据上下文{context}, 将原问题优化生成{nums}个相关的搜索查询,这些查询应与原始查询相似并且是人们可能会提出的可回答的搜索问题。请勿使用任何示例中提到的内容,确保所有生成的查询均独立于示例,仅基于提供的原始查询。请按照以下逗号分隔的格式提供: 'queries:' +REWRITE_PROMPT_TEMPLATE_ZH = """请根据上下文{context}, 将原问题优化生成{nums}个相关的搜索查询, +这些查询应与原始查询相似并且是人们可能会提出的可回答的搜索问题。请勿使用任何示例中提到的内容,确保所有 +生成的查询均独立于示例,仅基于提供的原始查询。请按照以下逗号分隔的格式提供: 'queries:' "original_query:{original_query}\n" "queries:" """ class QueryRewrite: - """ + """Query rewrite. + query reinforce, include query rewrite, query correct """ def __init__( self, - model_name: str = None, - llm_client: Optional[LLMClient] = None, + model_name: str, + llm_client: LLMClient, language: Optional[str] = "en", ) -> None: - """query rewrite + """Create QueryRewrite with model_name, llm_client, language. + Args: - - query: (str), user query - - model_name: (str), llm model name - - llm_client: (Optional[LLMClient]) - - language: (Optional[str]), language + model_name(str): model name + llm_client(LLMClient, optional): llm client + language(str, optional): language """ self._model_name = model_name self._llm_client = llm_client @@ -44,11 +49,13 @@ def __init__( async def rewrite( self, origin_query: str, context: Optional[str], nums: Optional[int] = 1 ) -> List[str]: - """query rewrite + """Query rewrite. + Args: origin_query: str original query context: Optional[str] context nums: Optional[int] rewrite nums + Returns: queries: List[str] """ @@ -75,13 +82,16 @@ async def rewrite( print(f"rewrite queries: {new_queries}") return new_queries - def correct(self) -> List[str]: + def correct(self) -> List[str] | None: + """Query correct.""" pass def _parse_llm_output(self, output: str) -> List[str]: - """parse llm output + """Parse llm output. + Args: output: str + Returns: output: List[str] """ diff --git a/dbgpt/rag/schemalinker/__init__.py b/dbgpt/rag/schemalinker/__init__.py index e69de29bb..ab1166d52 100644 --- a/dbgpt/rag/schemalinker/__init__.py +++ b/dbgpt/rag/schemalinker/__init__.py @@ -0,0 +1 @@ +"""Module of SchemaLinker.""" diff --git a/dbgpt/rag/schemalinker/base_linker.py b/dbgpt/rag/schemalinker/base_linker.py index 592ff997e..e74820884 100644 --- a/dbgpt/rag/schemalinker/base_linker.py +++ b/dbgpt/rag/schemalinker/base_linker.py @@ -1,3 +1,5 @@ +"""Base Linker.""" + from abc import ABC, abstractmethod from typing import List @@ -6,7 +8,8 @@ class BaseSchemaLinker(ABC): """Base Linker.""" def schema_linking(self, query: str) -> List: - """ + """Query schema info. + Args: query (str): query text Returns: @@ -15,7 +18,8 @@ def schema_linking(self, query: str) -> List: return self._schema_linking(query) def schema_linking_with_vector_db(self, query: str) -> List: - """ + """Query schema info with vector db. + Args: query (str): query text Returns: @@ -24,7 +28,8 @@ def schema_linking_with_vector_db(self, query: str) -> List: return self._schema_linking_with_vector_db(query) async def schema_linking_with_llm(self, query: str) -> List: - """ " + """Query schema info with llm. + Args: query(str): query text Returns: @@ -34,7 +39,8 @@ async def schema_linking_with_llm(self, query: str) -> List: @abstractmethod def _schema_linking(self, query: str) -> List: - """ + """Get DB schema info. + Args: query (str): query text Returns: @@ -43,7 +49,8 @@ def _schema_linking(self, query: str) -> List: @abstractmethod def _schema_linking_with_vector_db(self, query: str) -> List: - """ + """Query schema info with vector db. + Args: query (str): query text Returns: @@ -52,7 +59,8 @@ def _schema_linking_with_vector_db(self, query: str) -> List: @abstractmethod async def _schema_linking_with_llm(self, query: str) -> List: - """ + """Query schema info with llm. + Args: query (str): query text Returns: diff --git a/dbgpt/rag/schemalinker/schema_linking.py b/dbgpt/rag/schemalinker/schema_linking.py index 680f08245..748b3a928 100644 --- a/dbgpt/rag/schemalinker/schema_linking.py +++ b/dbgpt/rag/schemalinker/schema_linking.py @@ -1,5 +1,7 @@ +"""SchemaLinking by LLM.""" + from functools import reduce -from typing import List, Optional +from typing import List, Optional, cast from dbgpt.core import LLMClient, ModelMessage, ModelMessageRoleType, ModelRequest from dbgpt.datasource.rdbms.base import RDBMSDatabase @@ -9,32 +11,41 @@ from dbgpt.storage.vector_store.connector import VectorStoreConnector from dbgpt.util.chat_util import run_async_tasks -INSTRUCTION = ( - "You need to filter out the most relevant database table schema information (it may be a single " - "table or multiple tables) required to generate the SQL of the question query from the given " - "database schema information. First, I will show you an example of an instruction followed by " - "the correct schema response. Then, I will give you a new instruction, and you should write " - "the schema response that appropriately completes the request.\n### Example1 Instruction:\n" - "['job(id, name, age)', 'user(id, name, age)', 'student(id, name, age, info)']\n### Example1 " - "Input:\nFind the age of student table\n### Example1 Response:\n['student(id, name, age, info)']" - "\n###New Instruction:\n{}" -) +INSTRUCTION = """ +You need to filter out the most relevant database table schema information (it may be a + single table or multiple tables) required to generate the SQL of the question query + from the given database schema information. First, I will show you an example of an + instruction followed by the correct schema response. Then, I will give you a new + instruction, and you should write the schema response that appropriately completes the + request. + +### Example1 Instruction: +['job(id, name, age)', 'user(id, name, age)', 'student(id, name, age, info)'] +### Example1 Input: +Find the age of student table +### Example1 Response: +['student(id, name, age, info)'] +###New Instruction: +{} +""" + INPUT_PROMPT = "\n###New Input:\n{}\n###New Response:" class SchemaLinking(BaseSchemaLinker): - """SchemaLinking by LLM""" + """SchemaLinking by LLM.""" def __init__( self, + connection: RDBMSDatabase, + model_name: str, + llm: LLMClient, top_k: int = 5, - connection: Optional[RDBMSDatabase] = None, - llm: Optional[LLMClient] = None, - model_name: Optional[str] = None, vector_store_connector: Optional[VectorStoreConnector] = None, **kwargs ): - """ + """Create the schema linking instance. + Args: connection (Optional[RDBMSDatabase]): RDBMSDatabase connection. llm (Optional[LLMClient]): base llm @@ -47,20 +58,21 @@ def __init__( self._vector_store_connector = vector_store_connector def _schema_linking(self, query: str) -> List: - """get all db schema info""" + """Get all db schema info.""" table_summaries = _parse_db_summary(self._connection) chunks = [Chunk(content=table_summary) for table_summary in table_summaries] chunks_content = [chunk.content for chunk in chunks] return chunks_content - def _schema_linking_with_vector_db(self, query: str) -> List: + def _schema_linking_with_vector_db(self, query: str) -> List[Chunk]: queries = [query] + if not self._vector_store_connector: + raise ValueError("Vector store connector is not provided.") candidates = [ self._vector_store_connector.similar_search(query, self._top_k) for query in queries ] - candidates = reduce(lambda x, y: x + y, candidates) - return candidates + return cast(List[Chunk], reduce(lambda x, y: x + y, candidates)) async def _schema_linking_with_llm(self, query: str) -> List: chunks_content = self.schema_linking(query) diff --git a/dbgpt/rag/summary/__init__.py b/dbgpt/rag/summary/__init__.py index e69de29bb..d57ffeddd 100644 --- a/dbgpt/rag/summary/__init__.py +++ b/dbgpt/rag/summary/__init__.py @@ -0,0 +1,18 @@ +"""Module for summary related classes and functions.""" +from .db_summary import ( # noqa: F401 + DBSummary, + FieldSummary, + IndexSummary, + TableSummary, +) +from .db_summary_client import DBSummaryClient # noqa: F401 +from .rdbms_db_summary import RdbmsSummary # noqa: F401 + +__all__ = [ + "DBSummary", + "FieldSummary", + "IndexSummary", + "TableSummary", + "DBSummaryClient", + "RdbmsSummary", +] diff --git a/dbgpt/rag/summary/db_summary.py b/dbgpt/rag/summary/db_summary.py index 86306a31d..998953117 100644 --- a/dbgpt/rag/summary/db_summary.py +++ b/dbgpt/rag/summary/db_summary.py @@ -1,31 +1,48 @@ +"""Summary classes for database, table, field and index.""" +from typing import Dict, Iterable, List, Optional, Tuple + + class DBSummary: - def __init__(self, name): + """Database summary class.""" + + def __init__(self, name: str): + """Create a new DBSummary.""" self.name = name - self.summary = None - self.tables = [] - self.metadata = str + self.summary: Optional[str] = None + self.tables: Iterable[str] = [] + self.metadata: Optional[str] = None - def get_summary(self): + def get_summary(self) -> Optional[str]: + """Get the summary.""" return self.summary class TableSummary: - def __init__(self, name): + """Table summary class.""" + + def __init__(self, name: str): + """Create a new TableSummary.""" self.name = name - self.summary = None - self.fields = [] - self.indexes = [] + self.summary: Optional[str] = None + self.fields: List[Tuple] = [] + self.indexes: List[Dict] = [] class FieldSummary: - def __init__(self, name): + """Field summary class.""" + + def __init__(self, name: str): + """Create a new FieldSummary.""" self.name = name self.summary = None self.data_type = None class IndexSummary: - def __init__(self, name): + """Index summary class.""" + + def __init__(self, name: str): + """Create a new IndexSummary.""" self.name = name - self.summary = None - self.bind_fields = [] + self.summary: Optional[str] = None + self.bind_fields: List[str] = [] diff --git a/dbgpt/rag/summary/db_summary_client.py b/dbgpt/rag/summary/db_summary_client.py index 359d62cff..6cb725b0b 100644 --- a/dbgpt/rag/summary/db_summary_client.py +++ b/dbgpt/rag/summary/db_summary_client.py @@ -1,3 +1,5 @@ +"""DBSummaryClient class.""" + import logging import traceback @@ -12,26 +14,31 @@ class DBSummaryClient: - """DB Summary client, provide db_summary_embedding(put db profile and table profile summary into vector store) - , get_similar_tables method(get user query related tables info) + """The client for DBSummary. + + DB Summary client, provide db_summary_embedding(put db profile and table profile + summary into vector store), get_similar_tables method(get user query related tables + info) + Args: - system_app (SystemApp): Main System Application class that manages the lifecycle and registration of components.. + system_app (SystemApp): Main System Application class that manages the + lifecycle and registration of components.. """ def __init__(self, system_app: SystemApp): + """Create a new DBSummaryClient.""" self.system_app = system_app from dbgpt.rag.embedding.embedding_factory import EmbeddingFactory - embedding_factory = self.system_app.get_component( - "embedding_factory", EmbeddingFactory + embedding_factory: EmbeddingFactory = self.system_app.get_component( + "embedding_factory", component_type=EmbeddingFactory ) self.embeddings = embedding_factory.create( model_name=EMBEDDING_MODEL_CONFIG[CFG.EMBEDDING_MODEL] ) def db_summary_embedding(self, dbname, db_type): - """put db profile and table profile summary into vector store""" - + """Put db profile and table profile summary into vector store.""" db_summary_client = RdbmsSummary(dbname, db_type) self.init_db_profile(db_summary_client, dbname) @@ -39,8 +46,7 @@ def db_summary_embedding(self, dbname, db_type): logger.info("db summary embedding success") def get_db_summary(self, dbname, query, topk): - """get user query related tables info""" - + """Get user query related tables info.""" from dbgpt.storage.vector_store.base import VectorStoreConfig from dbgpt.storage.vector_store.connector import VectorStoreConnector @@ -60,7 +66,7 @@ def get_db_summary(self, dbname, query, topk): return ans def init_db_summary(self): - """init db summary""" + """Initialize db summary profile.""" db_mange = CFG.LOCAL_DB_MANAGE dbs = db_mange.get_db_list() for item in dbs: @@ -69,11 +75,13 @@ def init_db_summary(self): except Exception as e: message = traceback.format_exc() logger.warn( - f'{item["db_name"]}, {item["db_type"]} summary error!{str(e)}, detail: {message}' + f'{item["db_name"]}, {item["db_type"]} summary error!{str(e)}, ' + f"detail: {message}" ) def init_db_profile(self, db_summary_client, dbname): - """db profile initialization + """Initialize db summary profile. + Args: db_summary_client(DBSummaryClient): DB Summary Client dbname(str): dbname diff --git a/dbgpt/rag/summary/rdbms_db_summary.py b/dbgpt/rag/summary/rdbms_db_summary.py index 0b7333aa7..224dc3492 100644 --- a/dbgpt/rag/summary/rdbms_db_summary.py +++ b/dbgpt/rag/summary/rdbms_db_summary.py @@ -1,3 +1,5 @@ +"""Summary for rdbms database.""" + from typing import List from dbgpt._private.config import Config @@ -9,21 +11,28 @@ class RdbmsSummary(DBSummary): """Get rdbms db table summary template. - summary example: - table_name(column1(column1 comment),column2(column2 comment),column3(column3 comment) and index keys, and table comment is {table_comment}) + + Summary example: + table_name(column1(column1 comment),column2(column2 comment), + column3(column3 comment) and index keys, and table comment is {table_comment}) """ - def __init__(self, name, type): + def __init__(self, name: str, type: str): + """Create a new RdbmsSummary.""" self.name = name self.type = type self.summary_template = "{table_name}({columns})" self.tables = {} - self.tables_info = [] - self.vector_tables_info = [] + # self.tables_info = [] + # self.vector_tables_info = [] + if not CFG.LOCAL_DB_MANAGE: + raise ValueError("Local db manage is not initialized.") + # TODO: Don't use the global variable. self.db = CFG.LOCAL_DB_MANAGE.get_connect(name) - self.metadata = """user info :{users}, grant info:{grant}, charset:{charset}, collation:{collation}""".format( + self.metadata = """user info :{users}, grant info:{grant}, charset:{charset}, + collation:{collation}""".format( users=self.db.get_users(), grant=self.db.get_grants(), charset=self.db.get_charset(), @@ -36,8 +45,10 @@ def __init__(self, name, type): def get_table_summary(self, table_name): """Get table summary for table. + example: - table_name(column1(column1 comment),column2(column2 comment),column3(column3 comment) and index keys, and table comment: {table_comment}) + table_name(column1(column1 comment),column2(column2 comment), + column3(column3 comment) and index keys, and table comment: {table_comment}) """ return _parse_table_summary(self.db, self.summary_template, table_name) @@ -74,7 +85,8 @@ def _parse_table_summary( table_name (str): table name Examples: - table_name(column1(column1 comment),column2(column2 comment),column3(column3 comment) and index keys, and table comment: {table_comment}) + table_name(column1(column1 comment),column2(column2 comment), + column3(column3 comment) and index keys, and table comment: {table_comment}) """ columns = [] for column in conn.get_columns(table_name): diff --git a/dbgpt/rag/text_splitter/__init__.py b/dbgpt/rag/text_splitter/__init__.py index e69de29bb..3bf1ad075 100644 --- a/dbgpt/rag/text_splitter/__init__.py +++ b/dbgpt/rag/text_splitter/__init__.py @@ -0,0 +1,23 @@ +"""Text splitter module.""" + +from .pre_text_splitter import PreTextSplitter # noqa: F401 +from .text_splitter import ( # noqa: F401 + CharacterTextSplitter, + MarkdownHeaderTextSplitter, + PageTextSplitter, + ParagraphTextSplitter, + SeparatorTextSplitter, + SpacyTextSplitter, + TextSplitter, +) + +__ALL__ = [ + "PreTextSplitter", + "CharacterTextSplitter", + "MarkdownHeaderTextSplitter", + "PageTextSplitter", + "ParagraphTextSplitter", + "SeparatorTextSplitter", + "SpacyTextSplitter", + "TextSplitter", +] diff --git a/dbgpt/rag/text_splitter/pre_text_splitter.py b/dbgpt/rag/text_splitter/pre_text_splitter.py index 32178dafe..3c43cf8d2 100644 --- a/dbgpt/rag/text_splitter/pre_text_splitter.py +++ b/dbgpt/rag/text_splitter/pre_text_splitter.py @@ -1,3 +1,4 @@ +"""Pre text splitter.""" from typing import Iterable, List from dbgpt.rag.chunk import Chunk, Document @@ -7,8 +8,8 @@ def _single_document_split( document: Document, pre_separator: str ) -> Iterable[Document]: - content = document.content - for i, content in enumerate(content.split(pre_separator)): + origin_content = document.content + for i, content in enumerate(origin_content.split(pre_separator)): metadata = document.metadata.copy() if "source" in metadata: metadata["source"] = metadata["source"] + "_pre_split_" + str(i) @@ -16,10 +17,11 @@ def _single_document_split( class PreTextSplitter(TextSplitter): - """Split text by pre separator""" + """Split text by pre separator.""" def __init__(self, pre_separator: str, text_splitter_impl: TextSplitter): - """Initialize with Knowledge arguments. + """Create the pre text splitter instance. + Args: pre_separator: pre separator text_splitter_impl: text splitter impl @@ -28,11 +30,11 @@ def __init__(self, pre_separator: str, text_splitter_impl: TextSplitter): self._impl = text_splitter_impl def split_text(self, text: str, **kwargs) -> List[str]: - """Split text by pre separator""" + """Split text by pre separator.""" return self._impl.split_text(text) def split_documents(self, documents: Iterable[Document], **kwargs) -> List[Chunk]: - """Split documents by pre separator""" + """Split documents by pre separator.""" def generator() -> Iterable[Document]: for doc in documents: diff --git a/dbgpt/rag/text_splitter/text_splitter.py b/dbgpt/rag/text_splitter/text_splitter.py index 75d74fb41..8516dc279 100644 --- a/dbgpt/rag/text_splitter/text_splitter.py +++ b/dbgpt/rag/text_splitter/text_splitter.py @@ -1,17 +1,9 @@ +"""Text splitter module for splitting text into chunks.""" + import copy import logging from abc import ABC, abstractmethod -from typing import ( - Any, - Callable, - Dict, - Iterable, - List, - Optional, - Tuple, - TypedDict, - Union, -) +from typing import Any, Callable, Dict, Iterable, List, Optional, TypedDict, Union, cast from dbgpt.rag.chunk import Chunk, Document @@ -20,7 +12,9 @@ class TextSplitter(ABC): """Interface for splitting text into chunks. - Refer to https://github.com/langchain-ai/langchain/blob/master/libs/langchain/langchain/text_splitter.py + + Refer to `Langchain Text Splitter `_ """ outgoing_edges = 1 @@ -30,10 +24,12 @@ def __init__( chunk_size: int = 4000, chunk_overlap: int = 200, length_function: Callable[[str], int] = len, - filters: list = [], + filters=None, separator: str = "", ): """Create a new TextSplitter.""" + if filters is None: + filters = [] if chunk_overlap > chunk_size: raise ValueError( f"Got a larger chunk overlap ({chunk_overlap}) than chunk size " @@ -60,7 +56,7 @@ def create_documents( _metadatas = metadatas or [{}] * len(texts) chunks = [] for i, text in enumerate(texts): - for chunk in self.split_text(text, separator, **kwargs): + for chunk in self.split_text(text, separator=separator, **kwargs): new_doc = Chunk(content=chunk, metadata=copy.deepcopy(_metadatas[i])) chunks.append(new_doc) return chunks @@ -85,8 +81,8 @@ def _join_docs(self, docs: List[str], separator: str, **kwargs) -> Optional[str] def _merge_splits( self, - splits: Iterable[str], - separator: str, + splits: Iterable[str | dict], + separator: Optional[str] = None, chunk_size: Optional[int] = None, chunk_overlap: Optional[int] = None, ) -> List[str]: @@ -103,7 +99,8 @@ def _merge_splits( docs = [] current_doc: List[str] = [] total = 0 - for d in splits: + for s in splits: + d = cast(str, s) _len = self._length_function(d) if ( total + _len + (separator_len if len(current_doc) > 0 else 0) @@ -138,6 +135,7 @@ def _merge_splits( return docs def clean(self, documents: List[dict], filters: List[str]): + """Clean the documents.""" for special_character in filters: for doc in documents: doc["content"] = doc["content"].replace(special_character, "") @@ -146,12 +144,13 @@ def clean(self, documents: List[dict], filters: List[str]): def run( # type: ignore self, documents: Union[dict, List[dict]], - meta: Optional[Union[Dict[str, str], List[Dict[str, str]]]] = None, # type: ignore + meta: Optional[Union[Dict[str, str], List[Dict[str, str]]]] = None, separator: Optional[str] = None, chunk_size: Optional[int] = None, chunk_overlap: Optional[int] = None, filters: Optional[List[str]] = None, ): + """Run the text splitter.""" if separator is None: separator = self._separator if chunk_size is None: @@ -203,12 +202,16 @@ def run( # type: ignore class CharacterTextSplitter(TextSplitter): """Implementation of splitting text that looks at characters. - Refer to https://github.com/langchain-ai/langchain/blob/master/libs/langchain/langchain/text_splitter.py + + Refer to `Langchain Test Splitter `_ """ - def __init__(self, separator: str = "\n\n", filters: list = [], **kwargs: Any): + def __init__(self, separator: str = "\n\n", filters=None, **kwargs: Any): """Create a new TextSplitter.""" super().__init__(**kwargs) + if filters is None: + filters = [] self._separator = separator self._filter = filters @@ -228,9 +231,12 @@ def split_text( class RecursiveCharacterTextSplitter(TextSplitter): """Implementation of splitting text that looks at characters. + Recursively tries to split by different characters to find one that works. - Refer to https://github.com/langchain-ai/langchain/blob/master/libs/langchain/langchain/text_splitter.py + + Refer to `Langchain Test Splitter `_ """ def __init__(self, separators: Optional[List[str]] = None, **kwargs: Any): @@ -287,7 +293,9 @@ def split_text( class SpacyTextSplitter(TextSplitter): """Implementation of splitting text that looks at sentences using Spacy. - Refer to https://github.com/langchain-ai/langchain/blob/master/libs/langchain/langchain/text_splitter.py + + Refer to `Langchain Test Splitter `_ """ def __init__(self, pipeline: str = "zh_core_web_sm", **kwargs: Any) -> None: @@ -301,7 +309,7 @@ def __init__(self, pipeline: str = "zh_core_web_sm", **kwargs: Any) -> None: ) try: self._tokenizer = spacy.load(pipeline) - except: + except Exception: spacy.cli.download(pipeline) self._tokenizer = spacy.load(pipeline) @@ -332,23 +340,18 @@ class LineType(TypedDict): class MarkdownHeaderTextSplitter(TextSplitter): """Implementation of splitting markdown files based on specified headers. - Refer to https://github.com/langchain-ai/langchain/blob/master/libs/langchain/langchain/text_splitter.py + + Refer to `Langchain Text Splitter `_ """ outgoing_edges = 1 def __init__( self, - headers_to_split_on: List[Tuple[str, str]] = [ - ("#", "Header 1"), - ("##", "Header 2"), - ("###", "Header 3"), - ("####", "Header 4"), - ("#####", "Header 5"), - ("######", "Header 6"), - ], + headers_to_split_on=None, return_each_line: bool = False, - filters: list = [], + filters=None, chunk_size: int = 4000, chunk_overlap: int = 200, length_function: Callable[[str], int] = len, @@ -361,6 +364,17 @@ def __init__( return_each_line: Return each line w/ associated headers """ # Output line-by-line or aggregated into chunks w/ common headers + if headers_to_split_on is None: + headers_to_split_on = [ + ("#", "Header 1"), + ("##", "Header 2"), + ("###", "Header 3"), + ("####", "Header 4"), + ("#####", "Header 5"), + ("######", "Header 6"), + ] + if filters is None: + filters = [] self.return_each_line = return_each_line self._chunk_size = chunk_size # Given the headers we want to split on, @@ -392,7 +406,8 @@ def create_documents( return chunks def aggregate_lines_to_chunks(self, lines: List[LineType]) -> List[Chunk]: - """Combine lines with common metadata into chunks + """Aggregate lines into chunks based on common metadata. + Args: lines: Line of text / associated header metadata """ @@ -417,16 +432,22 @@ def aggregate_lines_to_chunks(self, lines: List[LineType]) -> List[Chunk]: for chunk in aggregated_chunks ] - def split_text( + def split_text( # type: ignore self, text: str, separator: Optional[str] = None, chunk_size: Optional[int] = None, chunk_overlap: Optional[int] = None, + **kwargs, ) -> List[Chunk]: - """Split markdown file + """Split incoming text and return chunks. + Args: - text: Markdown file""" + text(str): The input text + separator(str): The separator to use for splitting the text + chunk_size(int): The size of each chunk + chunk_overlap(int): The overlap between chunks + """ if separator is None: separator = self._separator if chunk_size is None: @@ -527,6 +548,7 @@ def split_text( ] def clean(self, documents: List[dict], filters: Optional[List[str]] = None): + """Clean the documents.""" if filters is None: filters = self._filter for special_character in filters: @@ -534,7 +556,7 @@ def clean(self, documents: List[dict], filters: Optional[List[str]] = None): doc["content"] = doc["content"].replace(special_character, "") return documents - def _join_docs(self, docs: List[str], separator: str) -> Optional[str]: + def _join_docs(self, docs: List[str], separator: str, **kwargs) -> Optional[str]: text = separator.join(docs) text = text.strip() if text == "": @@ -544,10 +566,10 @@ def _join_docs(self, docs: List[str], separator: str) -> Optional[str]: def _merge_splits( self, - documents: List[dict], + documents: Iterable[str | dict], separator: Optional[str] = None, chunk_size: Optional[int] = None, - chunk_overlap: [int] = None, + chunk_overlap: Optional[int] = None, ) -> List[str]: # We now want to combine these smaller pieces into medium size # chunks to send to the LLM. @@ -562,14 +584,15 @@ def _merge_splits( docs = [] current_doc: List[str] = [] total = 0 - for doc in documents: - if doc["metadata"] != {}: + for _doc in documents: + dict_doc = cast(dict, _doc) + if dict_doc["metadata"] != {}: head = sorted( - doc["metadata"].items(), key=lambda x: x[0], reverse=True + dict_doc["metadata"].items(), key=lambda x: x[0], reverse=True )[0][1] - d = head + separator + doc["page_content"] + d = head + separator + dict_doc["page_content"] else: - d = doc["page_content"] + d = dict_doc["page_content"] _len = self._length_function(d) if ( total + _len + (separator_len if len(current_doc) > 0 else 0) @@ -607,11 +630,12 @@ def run( self, documents: Union[dict, List[dict]], meta: Optional[Union[Dict[str, str], List[Dict[str, str]]]] = None, - filters: Optional[List[str]] = None, + separator: Optional[str] = None, chunk_size: Optional[int] = None, chunk_overlap: Optional[int] = None, - separator: Optional[str] = None, + filters: Optional[List[str]] = None, ): + """Run the text splitter.""" if filters is None: filters = self._filter if chunk_size is None: @@ -627,11 +651,10 @@ def run( document["content"], separator, chunk_size, chunk_overlap ) for i, txt in enumerate(text_splits): - doc = {} - doc["content"] = txt + doc = {"content": txt} if "meta" not in doc.keys() or doc["meta"] is None: - doc["meta"] = {} + doc["meta"] = {} # type: ignore doc["meta"]["_split_id"] = i ret.append(doc) @@ -640,11 +663,10 @@ def run( documents["content"], separator, chunk_size, chunk_overlap ) for i, txt in enumerate(text_splits): - doc = {} - doc["content"] = txt + doc = {"content": txt} if "meta" not in doc.keys() or doc["meta"] is None: - doc["meta"] = {} + doc["meta"] = {} # type: ignore doc["meta"]["_split_id"] = i ret.append(doc) @@ -662,9 +684,10 @@ class ParagraphTextSplitter(CharacterTextSplitter): def __init__( self, separator="\n", - chunk_size: Optional[int] = 0, - chunk_overlap: Optional[int] = 0, + chunk_size: int = 0, + chunk_overlap: int = 0, ): + """Create a new ParagraphTextSplitter.""" self._separator = separator if self._separator is None: self._separator = "\n" @@ -675,16 +698,19 @@ def __init__( def split_text( self, text: str, separator: Optional[str] = "\n", **kwargs ) -> List[str]: + """Split incoming text and return chunks.""" paragraphs = text.strip().split(self._separator) paragraphs = [p.strip() for p in paragraphs if p.strip() != ""] return paragraphs class SeparatorTextSplitter(CharacterTextSplitter): - """SeparatorTextSplitter""" + """The SeparatorTextSplitter class.""" - def __init__(self, separator: str = "\n", filters: list = [], **kwargs: Any): + def __init__(self, separator: str = "\n", filters=None, **kwargs: Any): """Create a new TextSplitter.""" + if filters is None: + filters = [] self._merge = kwargs.pop("enable_merge") or False super().__init__(**kwargs) self._separator = separator @@ -706,11 +732,13 @@ def split_text( class PageTextSplitter(TextSplitter): - """PageTextSplitter""" + """The PageTextSplitter class.""" - def __init__(self, separator: str = "\n\n", filters: list = [], **kwargs: Any): + def __init__(self, separator: str = "\n\n", filters=None, **kwargs: Any): """Create a new TextSplitter.""" super().__init__(**kwargs) + if filters is None: + filters = [] self._separator = separator self._filter = filters @@ -718,7 +746,7 @@ def split_text( self, text: str, separator: Optional[str] = None, **kwargs ) -> List[str]: """Split incoming text and return chunks.""" - return text + return [text] def create_documents( self, diff --git a/dbgpt/rag/text_splitter/token_splitter.py b/dbgpt/rag/text_splitter/token_splitter.py index f00be63e4..5ae06967c 100644 --- a/dbgpt/rag/text_splitter/token_splitter.py +++ b/dbgpt/rag/text_splitter/token_splitter.py @@ -1,8 +1,7 @@ """Token splitter.""" from typing import Callable, List, Optional -from pydantic import BaseModel, Field, PrivateAttr - +from dbgpt._private.pydantic import BaseModel, Field, PrivateAttr from dbgpt.util.global_helper import globals_helper from dbgpt.util.splitter_utils import split_by_char, split_by_sep @@ -45,9 +44,11 @@ def __init__( tokenizer: Optional[Callable] = None, # callback_manager: Optional[CallbackManager] = None, separator: str = " ", - backup_separators: Optional[List[str]] = ["\n"], + backup_separators=None, ): """Initialize with parameters.""" + if backup_separators is None: + backup_separators = ["\n"] if chunk_overlap > chunk_size: raise ValueError( f"Got a larger chunk overlap ({chunk_overlap}) than chunk size " @@ -70,6 +71,7 @@ def __init__( @classmethod def class_name(cls) -> str: + """Return the class name.""" return "TokenTextSplitter" def split_text_metadata_aware(self, text: str, metadata_str: str) -> List[str]: diff --git a/dbgpt/util/speech/say.py b/dbgpt/util/speech/say.py index b362d84b7..dd4ca8414 100644 --- a/dbgpt/util/speech/say.py +++ b/dbgpt/util/speech/say.py @@ -36,7 +36,7 @@ def _get_voice_engine(config: Config) -> tuple[VoiceBase, VoiceBase]: default_voice_engine = GTTSVoice() if config.elevenlabs_api_key: voice_engine = ElevenLabsSpeech() - elif config.use_mac_os_tts == "True": + elif config.use_mac_os_tts: voice_engine = MacOSTTS() elif config.use_brian_tts == "True": voice_engine = BrianSpeech() diff --git a/examples/rag/db_schema_rag_example.py b/examples/rag/db_schema_rag_example.py index c101163aa..71efdfac5 100644 --- a/examples/rag/db_schema_rag_example.py +++ b/examples/rag/db_schema_rag_example.py @@ -2,7 +2,7 @@ from dbgpt.configs.model_config import MODEL_PATH, PILOT_PATH from dbgpt.datasource.rdbms.conn_sqlite import SQLiteTempConnect -from dbgpt.rag.embedding.embedding_factory import DefaultEmbeddingFactory +from dbgpt.rag.embedding import DefaultEmbeddingFactory from dbgpt.serve.rag.assembler.db_schema import DBSchemaAssembler from dbgpt.storage.vector_store.chroma_store import ChromaVectorConfig from dbgpt.storage.vector_store.connector import VectorStoreConnector diff --git a/examples/rag/embedding_rag_example.py b/examples/rag/embedding_rag_example.py index de493b59a..d21d662f6 100644 --- a/examples/rag/embedding_rag_example.py +++ b/examples/rag/embedding_rag_example.py @@ -3,8 +3,8 @@ from dbgpt.configs.model_config import MODEL_PATH, PILOT_PATH, ROOT_PATH from dbgpt.rag.chunk_manager import ChunkParameters -from dbgpt.rag.embedding.embedding_factory import DefaultEmbeddingFactory -from dbgpt.rag.knowledge.factory import KnowledgeFactory +from dbgpt.rag.embedding import DefaultEmbeddingFactory +from dbgpt.rag.knowledge import KnowledgeFactory from dbgpt.serve.rag.assembler.embedding import EmbeddingAssembler from dbgpt.storage.vector_store.chroma_store import ChromaVectorConfig from dbgpt.storage.vector_store.connector import VectorStoreConnector diff --git a/examples/rag/rag_embedding_api_example.py b/examples/rag/rag_embedding_api_example.py index b5f669ed8..b7014cdf6 100644 --- a/examples/rag/rag_embedding_api_example.py +++ b/examples/rag/rag_embedding_api_example.py @@ -29,7 +29,7 @@ from dbgpt.configs.model_config import PILOT_PATH, ROOT_PATH from dbgpt.rag.chunk_manager import ChunkParameters from dbgpt.rag.embedding import OpenAPIEmbeddings -from dbgpt.rag.knowledge.factory import KnowledgeFactory +from dbgpt.rag.knowledge import KnowledgeFactory from dbgpt.serve.rag.assembler.embedding import EmbeddingAssembler from dbgpt.storage.vector_store.chroma_store import ChromaVectorConfig from dbgpt.storage.vector_store.connector import VectorStoreConnector diff --git a/examples/rag/rewrite_rag_example.py b/examples/rag/rewrite_rag_example.py index a74f60b87..b60f788f9 100644 --- a/examples/rag/rewrite_rag_example.py +++ b/examples/rag/rewrite_rag_example.py @@ -1,10 +1,5 @@ -import asyncio -import os - -from dbgpt.model.proxy import OpenAILLMClient -from dbgpt.rag.retriever.rewrite import QueryRewrite - """Query rewrite example. + pre-requirements: 1. install openai python sdk ``` @@ -15,7 +10,7 @@ export OPENAI_API_KEY={your_openai_key} export OPENAI_API_BASE={your_openai_base} ``` - or + or ``` import os os.environ["OPENAI_API_KEY"] = {your_openai_key} @@ -26,6 +21,11 @@ python examples/rag/rewrite_rag_example.py """ +import asyncio + +from dbgpt.model.proxy import OpenAILLMClient +from dbgpt.rag.retriever import QueryRewrite + async def main(): query = "compare steve curry and lebron james" diff --git a/examples/rag/simple_dbschema_retriever_example.py b/examples/rag/simple_dbschema_retriever_example.py index 72c2dfcd3..354c72d78 100644 --- a/examples/rag/simple_dbschema_retriever_example.py +++ b/examples/rag/simple_dbschema_retriever_example.py @@ -1,19 +1,3 @@ -import os -from typing import Dict, List - -from pydantic import BaseModel, Field - -from dbgpt._private.config import Config -from dbgpt.configs.model_config import EMBEDDING_MODEL_CONFIG, MODEL_PATH, PILOT_PATH -from dbgpt.core.awel import DAG, HttpTrigger, JoinOperator, MapOperator -from dbgpt.datasource.rdbms.conn_sqlite import SQLiteTempConnect -from dbgpt.rag.chunk import Chunk -from dbgpt.rag.embedding.embedding_factory import DefaultEmbeddingFactory -from dbgpt.rag.operators.db_schema import DBSchemaRetrieverOperator -from dbgpt.serve.rag.operators.db_schema import DBSchemaAssemblerOperator -from dbgpt.storage.vector_store.chroma_store import ChromaVectorConfig -from dbgpt.storage.vector_store.connector import VectorStoreConnector - """AWEL: Simple rag db schema embedding operator example if you not set vector_store_connector, it will return all tables schema in database. @@ -38,6 +22,21 @@ --data '{"query": "what is user name?"}' """ +import os +from typing import Dict, List + +from pydantic import BaseModel, Field + +from dbgpt._private.config import Config +from dbgpt.configs.model_config import EMBEDDING_MODEL_CONFIG, PILOT_PATH +from dbgpt.core.awel import DAG, HttpTrigger, JoinOperator, MapOperator +from dbgpt.datasource.rdbms.conn_sqlite import SQLiteTempConnect +from dbgpt.rag.chunk import Chunk +from dbgpt.rag.embedding import DefaultEmbeddingFactory +from dbgpt.rag.operators import DBSchemaRetrieverOperator +from dbgpt.serve.rag.operators.db_schema import DBSchemaAssemblerOperator +from dbgpt.storage.vector_store.chroma_store import ChromaVectorConfig +from dbgpt.storage.vector_store.connector import VectorStoreConnector CFG = Config() diff --git a/examples/rag/simple_rag_embedding_example.py b/examples/rag/simple_rag_embedding_example.py index 86f248153..56e5f959f 100644 --- a/examples/rag/simple_rag_embedding_example.py +++ b/examples/rag/simple_rag_embedding_example.py @@ -1,3 +1,16 @@ +"""AWEL: Simple rag embedding operator example. + + Examples: + pre-requirements: + python examples/awel/simple_rag_embedding_example.py + ..code-block:: shell + curl --location --request POST 'http://127.0.0.1:5555/api/v1/awel/trigger/examples/rag/embedding' \ + --header 'Content-Type: application/json' \ + --data-raw '{ + "url": "https://docs.dbgpt.site/docs/awel" + }' +""" + import os from typing import Dict, List @@ -6,26 +19,13 @@ from dbgpt._private.config import Config from dbgpt.configs.model_config import EMBEDDING_MODEL_CONFIG, MODEL_PATH, PILOT_PATH from dbgpt.core.awel import DAG, HttpTrigger, MapOperator -from dbgpt.rag.embedding.embedding_factory import DefaultEmbeddingFactory -from dbgpt.rag.knowledge.base import KnowledgeType -from dbgpt.rag.operators.knowledge import KnowledgeOperator +from dbgpt.rag.embedding import DefaultEmbeddingFactory +from dbgpt.rag.knowledge import KnowledgeType +from dbgpt.rag.operators import KnowledgeOperator from dbgpt.serve.rag.operators.embedding import EmbeddingAssemblerOperator from dbgpt.storage.vector_store.chroma_store import ChromaVectorConfig from dbgpt.storage.vector_store.connector import VectorStoreConnector -"""AWEL: Simple rag embedding operator example - - Examples: - pre-requirements: - python examples/awel/simple_rag_embedding_example.py - ..code-block:: shell - curl --location --request POST 'http://127.0.0.1:5555/api/v1/awel/trigger/examples/rag/embedding' \ - --header 'Content-Type: application/json' \ - --data-raw '{ - "url": "https://docs.dbgpt.site/docs/awel" - }' -""" - CFG = Config() diff --git a/examples/rag/simple_rag_retriever_example.py b/examples/rag/simple_rag_retriever_example.py index b9c7ca97f..19ab78666 100644 --- a/examples/rag/simple_rag_retriever_example.py +++ b/examples/rag/simple_rag_retriever_example.py @@ -1,26 +1,8 @@ -import asyncio -import os -from typing import Dict, List - -from pydantic import BaseModel, Field - -from dbgpt._private.config import Config -from dbgpt.configs.model_config import EMBEDDING_MODEL_CONFIG, MODEL_PATH, PILOT_PATH -from dbgpt.core.awel import DAG, HttpTrigger, JoinOperator, MapOperator -from dbgpt.model.proxy import OpenAILLMClient -from dbgpt.rag.chunk import Chunk -from dbgpt.rag.embedding.embedding_factory import DefaultEmbeddingFactory -from dbgpt.rag.operators.embedding import EmbeddingRetrieverOperator -from dbgpt.rag.operators.rerank import RerankOperator -from dbgpt.rag.operators.rewrite import QueryRewriteOperator -from dbgpt.storage.vector_store.chroma_store import ChromaVectorConfig -from dbgpt.storage.vector_store.connector import VectorStoreConnector - """AWEL: Simple rag embedding operator example pre-requirements: 1. install openai python sdk - + ``` pip install openai ``` @@ -31,8 +13,8 @@ ``` 3. make sure you have vector store. if there are no data in vector store, please run examples/awel/simple_rag_embedding_example.py - - + + ensure your embedding model in DB-GPT/models/. Examples: @@ -44,6 +26,25 @@ }' """ +import os +from typing import Dict, List + +from pydantic import BaseModel, Field + +from dbgpt._private.config import Config +from dbgpt.configs.model_config import EMBEDDING_MODEL_CONFIG, PILOT_PATH +from dbgpt.core.awel import DAG, HttpTrigger, JoinOperator, MapOperator +from dbgpt.model.proxy import OpenAILLMClient +from dbgpt.rag.chunk import Chunk +from dbgpt.rag.embedding import DefaultEmbeddingFactory +from dbgpt.rag.operators import ( + EmbeddingRetrieverOperator, + QueryRewriteOperator, + RerankOperator, +) +from dbgpt.storage.vector_store.chroma_store import ChromaVectorConfig +from dbgpt.storage.vector_store.connector import VectorStoreConnector + CFG = Config() diff --git a/examples/rag/summary_extractor_example.py b/examples/rag/summary_extractor_example.py index 33b4cc585..385885ef4 100644 --- a/examples/rag/summary_extractor_example.py +++ b/examples/rag/summary_extractor_example.py @@ -1,10 +1,3 @@ -import asyncio - -from dbgpt.model.proxy import OpenAILLMClient -from dbgpt.rag.chunk_manager import ChunkParameters -from dbgpt.rag.knowledge.factory import KnowledgeFactory -from dbgpt.serve.rag.assembler.summary import SummaryAssembler - """Summary extractor example. pre-requirements: 1. install openai python sdk @@ -16,7 +9,7 @@ export OPENAI_API_KEY={your_openai_key} export OPENAI_API_BASE={your_openai_base} ``` - or + or ``` import os os.environ["OPENAI_API_KEY"] = {your_openai_key} @@ -28,6 +21,14 @@ """ +import asyncio + +from dbgpt.model.proxy import OpenAILLMClient +from dbgpt.rag.chunk_manager import ChunkParameters +from dbgpt.rag.knowledge import KnowledgeFactory +from dbgpt.serve.rag.assembler.summary import SummaryAssembler + + async def main(): file_path = "./docs/docs/awel.md" llm_client = OpenAILLMClient() diff --git a/requirements/lint-requirements.txt b/requirements/lint-requirements.txt index 91639ca6f..c8d53cc5d 100644 --- a/requirements/lint-requirements.txt +++ b/requirements/lint-requirements.txt @@ -9,3 +9,6 @@ flake8-simplify==0.19.3 flake8-tidy-imports==4.8.0 isort==5.10.1 pyupgrade==3.1.0 +types-requests +types-beautifulsoup4 +types-Markdown \ No newline at end of file