From bdf94423937c6995663ea81cd510c9325948dda1 Mon Sep 17 00:00:00 2001 From: Fangyin Cheng Date: Sat, 16 Dec 2023 20:59:54 +0800 Subject: [PATCH] fix(core): Fix not create database bug (#946) --- dbgpt/app/base.py | 47 ++++++++++++++++++++++++ dbgpt/app/tests/__init__.py | 0 dbgpt/app/tests/test_base.py | 71 ++++++++++++++++++++++++++++++++++++ 3 files changed, 118 insertions(+) create mode 100644 dbgpt/app/tests/__init__.py create mode 100644 dbgpt/app/tests/test_base.py diff --git a/dbgpt/app/base.py b/dbgpt/app/base.py index a8663876e..b911f24a9 100644 --- a/dbgpt/app/base.py +++ b/dbgpt/app/base.py @@ -2,6 +2,7 @@ import os import threading import sys +import logging from typing import Optional from dataclasses import dataclass, field @@ -14,6 +15,8 @@ ROOT_PATH = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) sys.path.append(ROOT_PATH) +logger = logging.getLogger(__name__) + def signal_handler(sig, frame): print("in order to avoid chroma db atexit problem") @@ -110,6 +113,8 @@ def _initialize_db(try_to_create_db: Optional[bool] = False) -> str: os.makedirs(default_meta_data_path, exist_ok=True) if CFG.LOCAL_DB_TYPE == "mysql": db_url = f"mysql+pymysql://{quote(CFG.LOCAL_DB_USER)}:{urlquote(CFG.LOCAL_DB_PASSWORD)}@{CFG.LOCAL_DB_HOST}:{str(CFG.LOCAL_DB_PORT)}/{db_name}" + # Try to create database, if failed, will raise exception + _create_mysql_database(db_name, db_url, try_to_create_db) else: sqlite_db_path = os.path.join(default_meta_data_path, f"{db_name}.db") db_url = f"sqlite:///{sqlite_db_path}" @@ -124,6 +129,48 @@ def _initialize_db(try_to_create_db: Optional[bool] = False) -> str: return default_meta_data_path +def _create_mysql_database(db_name: str, db_url: str, try_to_create_db: bool = False): + """Create mysql database if not exists + + Args: + db_name (str): The database name + db_url (str): The database url, include host, port, user, password and database name + try_to_create_db (bool, optional): Whether to try to create database. Defaults to False. + + Raises: + Exception: Raise exception if database operation failed + """ + from sqlalchemy import create_engine, DDL + from sqlalchemy.exc import SQLAlchemyError, OperationalError + + if not try_to_create_db: + logger.info(f"Skipping creation of database {db_name}") + return + engine = create_engine(db_url) + + try: + # Try to connect to the database + with engine.connect() as conn: + logger.info(f"Database {db_name} already exists") + return + except OperationalError as oe: + # If the error indicates that the database does not exist, try to create it + if "Unknown database" in str(oe): + try: + # Create the database + no_db_name_url = db_url.rsplit("/", 1)[0] + engine_no_db = create_engine(no_db_name_url) + with engine_no_db.connect() as conn: + conn.execute(DDL(f"CREATE DATABASE {db_name}")) + logger.info(f"Database {db_name} successfully created") + except SQLAlchemyError as e: + logger.error(f"Failed to create database {db_name}: {e}") + raise + else: + logger.error(f"Error connecting to database {db_name}: {oe}") + raise + + @dataclass class WebServerParameters(BaseParameters): host: Optional[str] = field( diff --git a/dbgpt/app/tests/__init__.py b/dbgpt/app/tests/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/dbgpt/app/tests/test_base.py b/dbgpt/app/tests/test_base.py new file mode 100644 index 000000000..4f5a21a9f --- /dev/null +++ b/dbgpt/app/tests/test_base.py @@ -0,0 +1,71 @@ +import pytest +from unittest.mock import patch, MagicMock +from sqlalchemy.exc import OperationalError, SQLAlchemyError + +from dbgpt.app.base import _create_mysql_database + + +@patch("sqlalchemy.create_engine") +@patch("dbgpt.app.base.logger") +def test_database_already_exists(mock_logger, mock_create_engine): + mock_connection = MagicMock() + mock_create_engine.return_value.connect.return_value.__enter__.return_value = ( + mock_connection + ) + + _create_mysql_database( + "test_db", "mysql+pymysql://user:password@host/test_db", True + ) + mock_logger.info.assert_called_with("Database test_db already exists") + mock_connection.execute.assert_not_called() + + +@patch("sqlalchemy.create_engine") +@patch("dbgpt.app.base.logger") +def test_database_creation_success(mock_logger, mock_create_engine): + # Mock the first connection failure, and the second connection success + mock_create_engine.side_effect = [ + MagicMock( + connect=MagicMock( + side_effect=OperationalError("Unknown database", None, None) + ) + ), + MagicMock(), + ] + + _create_mysql_database( + "test_db", "mysql+pymysql://user:password@host/test_db", True + ) + mock_logger.info.assert_called_with("Database test_db successfully created") + + +@patch("sqlalchemy.create_engine") +@patch("dbgpt.app.base.logger") +def test_database_creation_failure(mock_logger, mock_create_engine): + # Mock the first connection failure, and the second connection failure with SQLAlchemyError + mock_create_engine.side_effect = [ + MagicMock( + connect=MagicMock( + side_effect=OperationalError("Unknown database", None, None) + ) + ), + MagicMock(connect=MagicMock(side_effect=SQLAlchemyError("Creation failed"))), + ] + + with pytest.raises(SQLAlchemyError): + _create_mysql_database( + "test_db", "mysql+pymysql://user:password@host/test_db", True + ) + mock_logger.error.assert_called_with( + "Failed to create database test_db: Creation failed" + ) + + +@patch("sqlalchemy.create_engine") +@patch("dbgpt.app.base.logger") +def test_skip_database_creation(mock_logger, mock_create_engine): + _create_mysql_database( + "test_db", "mysql+pymysql://user:password@host/test_db", False + ) + mock_logger.info.assert_called_with("Skipping creation of database test_db") + mock_create_engine.assert_not_called()