Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

chore: Add pylint for DB-GPT rag lib #1267

Merged
merged 3 commits into from
Mar 7, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 40 additions & 0 deletions .mypy.ini
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,21 @@
exclude = /tests/
# plugins = pydantic.mypy

[mypy-dbgpt.app.*]
follow_imports = skip

[mypy-dbgpt.datasource.*]
follow_imports = skip

[mypy-dbgpt.storage.*]
follow_imports = skip

[mypy-dbgpt.serve.*]
follow_imports = skip

[mypy-dbgpt.util.*]
follow_imports = skip

[mypy-graphviz.*]
ignore_missing_imports = True

Expand All @@ -17,4 +32,29 @@ ignore_missing_imports = True
[mypy-pydantic.*]
strict_optional = False
ignore_missing_imports = True
follow_imports = skip

[mypy-sentence_transformers.*]
ignore_missing_imports = True

[mypy-InstructorEmbedding.*]
ignore_missing_imports = True

[mypy-llama_index.*]
ignore_missing_imports = True

[mypy-pptx.*]
ignore_missing_imports = True

[mypy-docx.*]
ignore_missing_imports = True

[mypy-markdown.*]
ignore_missing_imports = True

[mypy-auto_gpt_plugin_template.*]
ignore_missing_imports = True

[mypy-spacy.*]
ignore_missing_imports = True
follow_imports = skip
3 changes: 3 additions & 0 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ fmt: setup ## Format Python code
# TODO: Use flake8 to enforce Python style guide.
# https://flake8.pycqa.org/en/latest/
$(VENV_BIN)/flake8 dbgpt/core/
$(VENV_BIN)/flake8 dbgpt/rag/
# TODO: More package checks with flake8.

.PHONY: fmt-check
Expand All @@ -58,6 +59,7 @@ fmt-check: setup ## Check Python code formatting and style without making change
$(VENV_BIN)/black --check --extend-exclude="examples/notebook" .
$(VENV_BIN)/blackdoc --check dbgpt examples
$(VENV_BIN)/flake8 dbgpt/core/
$(VENV_BIN)/flake8 dbgpt/rag/
# $(VENV_BIN)/blackdoc --check dbgpt examples
# $(VENV_BIN)/flake8 dbgpt/core/

Expand All @@ -76,6 +78,7 @@ test-doc: $(VENV)/.testenv ## Run doctests
mypy: $(VENV)/.testenv ## Run mypy checks
# https://github.com/python/mypy
$(VENV_BIN)/mypy --config-file .mypy.ini dbgpt/core/
$(VENV_BIN)/mypy --config-file .mypy.ini dbgpt/rag/
# TODO: More package checks with mypy.

.PHONY: coverage
Expand Down
25 changes: 12 additions & 13 deletions dbgpt/_private/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ def __init__(self) -> None:
if self.zhipu_proxy_api_key:
os.environ["zhipu_proxyllm_proxy_api_key"] = self.zhipu_proxy_api_key
os.environ["zhipu_proxyllm_proxyllm_backend"] = os.getenv(
"ZHIPU_MODEL_VERSION"
"ZHIPU_MODEL_VERSION", ""
)

# wenxin
Expand All @@ -74,7 +74,9 @@ def __init__(self) -> None:
os.environ[
"wenxin_proxyllm_proxy_api_secret"
] = self.wenxin_proxy_api_secret
os.environ["wenxin_proxyllm_proxyllm_backend"] = self.wenxin_model_version
os.environ["wenxin_proxyllm_proxyllm_backend"] = (
self.wenxin_model_version or ""
)

# xunfei spark
self.spark_api_version = os.getenv("XUNFEI_SPARK_API_VERSION")
Expand All @@ -84,8 +86,10 @@ def __init__(self) -> None:
if self.spark_proxy_api_key and self.spark_proxy_api_secret:
os.environ["spark_proxyllm_proxy_api_key"] = self.spark_proxy_api_key
os.environ["spark_proxyllm_proxy_api_secret"] = self.spark_proxy_api_secret
os.environ["spark_proxyllm_proxyllm_backend"] = self.spark_api_version
os.environ["spark_proxyllm_proxy_api_app_id"] = self.spark_proxy_api_appid
os.environ["spark_proxyllm_proxyllm_backend"] = self.spark_api_version or ""
os.environ["spark_proxyllm_proxy_api_app_id"] = (
self.spark_proxy_api_appid or ""
)

# baichuan proxy
self.bc_proxy_api_key = os.getenv("BAICHUAN_PROXY_API_KEY")
Expand All @@ -108,12 +112,10 @@ def __init__(self) -> None:
self.elevenlabs_voice_1_id = os.getenv("ELEVENLABS_VOICE_1_ID")
self.elevenlabs_voice_2_id = os.getenv("ELEVENLABS_VOICE_2_ID")

self.use_mac_os_tts = False
self.use_mac_os_tts = os.getenv("USE_MAC_OS_TTS")
self.use_mac_os_tts = os.getenv("USE_MAC_OS_TTS", "False") == "True"

self.authorise_key = os.getenv("AUTHORISE_COMMAND_KEY", "y")
self.exit_key = os.getenv("EXIT_KEY", "n")
self.image_provider = os.getenv("IMAGE_PROVIDER", True)
self.image_size = int(os.getenv("IMAGE_SIZE", 256))

self.huggingface_api_token = os.getenv("HUGGINGFACE_API_TOKEN")
Expand All @@ -131,10 +133,7 @@ def __init__(self) -> None:

self.prompt_template_registry = PromptTemplateRegistry()
### Related configuration of built-in commands
self.command_registry = []

### Relate configuration of display commands
self.command_dispaly = []
self.command_registry = [] # type: ignore

disabled_command_categories = os.getenv("DISABLED_COMMAND_CATEGORIES")
if disabled_command_categories:
Expand All @@ -151,7 +150,7 @@ def __init__(self) -> None:
### The associated configuration parameters of the plug-in control the loading and use of the plug-in

self.plugins: List["AutoGPTPluginTemplate"] = []
self.plugins_openai = []
self.plugins_openai = [] # type: ignore
self.plugins_auto_load = os.getenv("AUTO_LOAD_PLUGIN", "True").lower() == "true"

self.plugins_git_branch = os.getenv("PLUGINS_GIT_BRANCH", "plugin_dashboard")
Expand Down Expand Up @@ -274,6 +273,6 @@ def __init__(self) -> None:
self.MODEL_CACHE_MAX_MEMORY_MB: int = int(
os.getenv("MODEL_CACHE_MAX_MEMORY_MB", 256)
)
self.MODEL_CACHE_STORAGE_DISK_DIR: str = os.getenv(
self.MODEL_CACHE_STORAGE_DISK_DIR: Optional[str] = os.getenv(
"MODEL_CACHE_STORAGE_DISK_DIR"
)
2 changes: 1 addition & 1 deletion dbgpt/component.py
Original file line number Diff line number Diff line change
Expand Up @@ -210,7 +210,7 @@ def register_instance(self, instance: T) -> T:
def get_component(
self,
name: Union[str, ComponentType],
component_type: Type[T],
component_type: Type,
default_component=_EMPTY_DEFAULT_COMPONENT,
or_register_component: Optional[Type[T]] = None,
*args,
Expand Down
2 changes: 1 addition & 1 deletion dbgpt/core/awel/operators/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ def apply_defaults(self: "BaseOperator", *args: Any, **kwargs: Any) -> Any:
if system_app:
executor = system_app.get_component(
ComponentType.EXECUTOR_DEFAULT, DefaultExecutorFactory
).create()
).create() # type: ignore
else:
executor = DefaultExecutorFactory().create()
DAGVar.set_executor(executor)
Expand Down
4 changes: 2 additions & 2 deletions dbgpt/datasource/rdbms/base.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from __future__ import annotations

from typing import Any, Dict, Iterable, List, Optional
from typing import Any, Dict, Iterable, List, Optional, Tuple
from urllib.parse import quote
from urllib.parse import quote_plus as urlquote

Expand Down Expand Up @@ -499,7 +499,7 @@ def get_show_create_table(self, table_name):
ans = cursor.fetchall()
return ans[0][1]

def get_fields(self, table_name):
def get_fields(self, table_name) -> List[Tuple]:
"""Get column fields about specified table."""
session = self._db_sessions()
cursor = session.execute(
Expand Down
4 changes: 2 additions & 2 deletions dbgpt/datasource/rdbms/conn_clickhouse.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import re
from typing import Any, Dict, Iterable, List, Optional
from typing import Any, Dict, Iterable, List, Optional, Tuple

import sqlparse
from sqlalchemy import MetaData, text
Expand Down Expand Up @@ -145,7 +145,7 @@ def get_columns(self, table_name: str) -> List[Dict]:
for name, column_type, _, _, comment in fields[0]
]

def get_fields(self, table_name):
def get_fields(self, table_name) -> List[Tuple]:
"""Get column fields about specified table."""
session = self.client

Expand Down
4 changes: 2 additions & 2 deletions dbgpt/datasource/rdbms/conn_doris.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Any, Iterable, Optional
from typing import Any, Iterable, List, Optional, Tuple
from urllib.parse import quote
from urllib.parse import quote_plus as urlquote

Expand Down Expand Up @@ -68,7 +68,7 @@ def get_users(self):
"""Get user info."""
return []

def get_fields(self, table_name):
def get_fields(self, table_name) -> List[Tuple]:
"""Get column fields about specified table."""
cursor = self.get_session().execute(
text(
Expand Down
4 changes: 2 additions & 2 deletions dbgpt/datasource/rdbms/conn_postgresql.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Any, Iterable, Optional
from typing import Any, Iterable, List, Optional, Tuple
from urllib.parse import quote
from urllib.parse import quote_plus as urlquote

Expand Down Expand Up @@ -85,7 +85,7 @@ def get_users(self):
print("postgresql get users error: ", e)
return []

def get_fields(self, table_name):
def get_fields(self, table_name) -> List[Tuple]:
"""Get column fields about specified table."""
session = self._db_sessions()
cursor = session.execute(
Expand Down
4 changes: 2 additions & 2 deletions dbgpt/datasource/rdbms/conn_sqlite.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import logging
import os
import tempfile
from typing import Any, Iterable, Optional
from typing import Any, Iterable, List, Optional, Tuple

from sqlalchemy import create_engine, text

Expand Down Expand Up @@ -58,7 +58,7 @@ def get_show_create_table(self, table_name):
ans = cursor.fetchall()
return ans[0][0]

def get_fields(self, table_name):
def get_fields(self, table_name) -> List[Tuple]:
"""Get column fields about specified table."""
cursor = self.session.execute(text(f"PRAGMA table_info('{table_name}')"))
fields = cursor.fetchall()
Expand Down
4 changes: 2 additions & 2 deletions dbgpt/datasource/rdbms/conn_starrocks.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Any, Iterable, Optional
from typing import Any, Iterable, List, Optional, Tuple
from urllib.parse import quote
from urllib.parse import quote_plus as urlquote

Expand Down Expand Up @@ -68,7 +68,7 @@ def get_users(self):
"""Get user info."""
return []

def get_fields(self, table_name, db_name="database()"):
def get_fields(self, table_name, db_name="database()") -> List[Tuple]:
"""Get column fields about specified table."""
session = self._db_sessions()
if db_name != "database()":
Expand Down
1 change: 1 addition & 0 deletions dbgpt/rag/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
"""Module of RAG."""
Loading
Loading