Skip to content

Commit

Permalink
fix(core): Fix not create database bug (eosphoros-ai#946)
Browse files Browse the repository at this point in the history
  • Loading branch information
fangyinc authored and Hopshine committed Sep 10, 2024
1 parent 3c231a7 commit 815cf5e
Show file tree
Hide file tree
Showing 3 changed files with 118 additions and 0 deletions.
47 changes: 47 additions & 0 deletions dbgpt/app/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import os
import threading
import sys
import logging
from typing import Optional
from dataclasses import dataclass, field

Expand All @@ -14,6 +15,8 @@
ROOT_PATH = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
sys.path.append(ROOT_PATH)

logger = logging.getLogger(__name__)


def signal_handler(sig, frame):
print("in order to avoid chroma db atexit problem")
Expand Down Expand Up @@ -110,6 +113,8 @@ def _initialize_db(try_to_create_db: Optional[bool] = False) -> str:
os.makedirs(default_meta_data_path, exist_ok=True)
if CFG.LOCAL_DB_TYPE == "mysql":
db_url = f"mysql+pymysql://{quote(CFG.LOCAL_DB_USER)}:{urlquote(CFG.LOCAL_DB_PASSWORD)}@{CFG.LOCAL_DB_HOST}:{str(CFG.LOCAL_DB_PORT)}/{db_name}"
# Try to create database, if failed, will raise exception
_create_mysql_database(db_name, db_url, try_to_create_db)
else:
sqlite_db_path = os.path.join(default_meta_data_path, f"{db_name}.db")
db_url = f"sqlite:///{sqlite_db_path}"
Expand All @@ -124,6 +129,48 @@ def _initialize_db(try_to_create_db: Optional[bool] = False) -> str:
return default_meta_data_path


def _create_mysql_database(db_name: str, db_url: str, try_to_create_db: bool = False):
"""Create mysql database if not exists
Args:
db_name (str): The database name
db_url (str): The database url, include host, port, user, password and database name
try_to_create_db (bool, optional): Whether to try to create database. Defaults to False.
Raises:
Exception: Raise exception if database operation failed
"""
from sqlalchemy import create_engine, DDL
from sqlalchemy.exc import SQLAlchemyError, OperationalError

if not try_to_create_db:
logger.info(f"Skipping creation of database {db_name}")
return
engine = create_engine(db_url)

try:
# Try to connect to the database
with engine.connect() as conn:
logger.info(f"Database {db_name} already exists")
return
except OperationalError as oe:
# If the error indicates that the database does not exist, try to create it
if "Unknown database" in str(oe):
try:
# Create the database
no_db_name_url = db_url.rsplit("/", 1)[0]
engine_no_db = create_engine(no_db_name_url)
with engine_no_db.connect() as conn:
conn.execute(DDL(f"CREATE DATABASE {db_name}"))
logger.info(f"Database {db_name} successfully created")
except SQLAlchemyError as e:
logger.error(f"Failed to create database {db_name}: {e}")
raise
else:
logger.error(f"Error connecting to database {db_name}: {oe}")
raise


@dataclass
class WebServerParameters(BaseParameters):
host: Optional[str] = field(
Expand Down
Empty file added dbgpt/app/tests/__init__.py
Empty file.
71 changes: 71 additions & 0 deletions dbgpt/app/tests/test_base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
import pytest
from unittest.mock import patch, MagicMock
from sqlalchemy.exc import OperationalError, SQLAlchemyError

from dbgpt.app.base import _create_mysql_database


@patch("sqlalchemy.create_engine")
@patch("dbgpt.app.base.logger")
def test_database_already_exists(mock_logger, mock_create_engine):
mock_connection = MagicMock()
mock_create_engine.return_value.connect.return_value.__enter__.return_value = (
mock_connection
)

_create_mysql_database(
"test_db", "mysql+pymysql://user:password@host/test_db", True
)
mock_logger.info.assert_called_with("Database test_db already exists")
mock_connection.execute.assert_not_called()


@patch("sqlalchemy.create_engine")
@patch("dbgpt.app.base.logger")
def test_database_creation_success(mock_logger, mock_create_engine):
# Mock the first connection failure, and the second connection success
mock_create_engine.side_effect = [
MagicMock(
connect=MagicMock(
side_effect=OperationalError("Unknown database", None, None)
)
),
MagicMock(),
]

_create_mysql_database(
"test_db", "mysql+pymysql://user:password@host/test_db", True
)
mock_logger.info.assert_called_with("Database test_db successfully created")


@patch("sqlalchemy.create_engine")
@patch("dbgpt.app.base.logger")
def test_database_creation_failure(mock_logger, mock_create_engine):
# Mock the first connection failure, and the second connection failure with SQLAlchemyError
mock_create_engine.side_effect = [
MagicMock(
connect=MagicMock(
side_effect=OperationalError("Unknown database", None, None)
)
),
MagicMock(connect=MagicMock(side_effect=SQLAlchemyError("Creation failed"))),
]

with pytest.raises(SQLAlchemyError):
_create_mysql_database(
"test_db", "mysql+pymysql://user:password@host/test_db", True
)
mock_logger.error.assert_called_with(
"Failed to create database test_db: Creation failed"
)


@patch("sqlalchemy.create_engine")
@patch("dbgpt.app.base.logger")
def test_skip_database_creation(mock_logger, mock_create_engine):
_create_mysql_database(
"test_db", "mysql+pymysql://user:password@host/test_db", False
)
mock_logger.info.assert_called_with("Skipping creation of database test_db")
mock_create_engine.assert_not_called()

0 comments on commit 815cf5e

Please sign in to comment.