diff --git a/src/nonebot_plugin_mahjong_scoreboard/model/orm/__init__.py b/src/nonebot_plugin_mahjong_scoreboard/model/orm/__init__.py index 91a18c0..40c9986 100644 --- a/src/nonebot_plugin_mahjong_scoreboard/model/orm/__init__.py +++ b/src/nonebot_plugin_mahjong_scoreboard/model/orm/__init__.py @@ -1,10 +1 @@ -from nonebot import get_driver, require - -from nonebot_plugin_mahjong_scoreboard.config import conf - -require("nonebot_plugin_sqlalchemy") -from nonebot_plugin_sqlalchemy import DataSource - -data_source = DataSource(get_driver(), conf.mahjong_scoreboard_database_conn_url) - -__all__ = ("data_source",) +from ._data_source import data_source diff --git a/src/nonebot_plugin_mahjong_scoreboard/model/orm/_data_source/__init__.py b/src/nonebot_plugin_mahjong_scoreboard/model/orm/_data_source/__init__.py new file mode 100644 index 0000000..81bf3c5 --- /dev/null +++ b/src/nonebot_plugin_mahjong_scoreboard/model/orm/_data_source/__init__.py @@ -0,0 +1,5 @@ +from . import metainfo +from . import migrations +from .src import data_source, dialect + +__all__ = ("data_source", "dialect") diff --git a/src/nonebot_plugin_mahjong_scoreboard/model/orm/_data_source/metainfo.py b/src/nonebot_plugin_mahjong_scoreboard/model/orm/_data_source/metainfo.py new file mode 100644 index 0000000..d82d1eb --- /dev/null +++ b/src/nonebot_plugin_mahjong_scoreboard/model/orm/_data_source/metainfo.py @@ -0,0 +1,55 @@ +from sqlalchemy import Column, String, JSON, inspect, select +from sqlalchemy.ext.asyncio import AsyncSession + +from .src import data_source + +APP_DB_VERSION = 2 + + +@data_source.registry.mapped +class MetaInfoOrm: + __tablename__ = 'metainfo' + + key: str = Column(String(64), primary_key=True) + value = Column(JSON) + + +async def get_metainfo(key: str): + async with AsyncSession(data_source.engine) as session: + record = await session.get(MetaInfoOrm, key) + return record.value + + +async def set_metainfo(key: str, value: any): + async with AsyncSession(data_source.engine) as session: + record = await session.get(MetaInfoOrm, key) + if record is None: + record = MetaInfoOrm(key=key, value=value) + session.add(record) + + record.value = value + await session.commit() + + +@data_source.on_engine_created +async def initialize_metainfo(): + async with data_source.engine.begin() as conn: + await conn.run_sync(lambda conn: MetaInfoOrm.__table__.create(conn, checkfirst=True)) + + async with data_source.engine.begin() as conn: + async with AsyncSession(data_source.engine, expire_on_commit=False) as session: + # 判断是否初次建库 + blank_database = not await conn.run_sync(lambda conn: inspect(conn).has_table("games")) + if blank_database: + result = MetaInfoOrm(key="db_version", value=APP_DB_VERSION) + session.add(result) + await session.commit() + else: + stmt = select(MetaInfoOrm).where(MetaInfoOrm.key == "db_version") + result = (await session.execute(stmt)).scalar_one_or_none() + if result is None: + result = MetaInfoOrm(key="db_version", value=1) + session.add(result) + await session.commit() + + await conn.commit() diff --git a/src/nonebot_plugin_mahjong_scoreboard/model/orm/_data_source/migrations/__init__.py b/src/nonebot_plugin_mahjong_scoreboard/model/orm/_data_source/migrations/__init__.py new file mode 100644 index 0000000..b32b522 --- /dev/null +++ b/src/nonebot_plugin_mahjong_scoreboard/model/orm/_data_source/migrations/__init__.py @@ -0,0 +1,21 @@ +from nonebot import logger + +from .v1_to_v2 import migrate_v1_to_v2 +from ..metainfo import get_metainfo, APP_DB_VERSION, set_metainfo +from ..src import data_source + +migrations = { + (1, 2): migrate_v1_to_v2 +} + + +@data_source.on_ready +async def do_migrate(): + cur_db_version = await get_metainfo('db_version') + while cur_db_version < APP_DB_VERSION: + mig = migrations[cur_db_version, cur_db_version + 1] + await mig() + await set_metainfo('db_version', cur_db_version + 1) + logger.success(f"migrate database from version {cur_db_version} to version {cur_db_version + 1}") + + cur_db_version += 1 diff --git a/src/nonebot_plugin_mahjong_scoreboard/model/orm/_data_source/migrations/v1_to_v2.py b/src/nonebot_plugin_mahjong_scoreboard/model/orm/_data_source/migrations/v1_to_v2.py new file mode 100644 index 0000000..fcff0f3 --- /dev/null +++ b/src/nonebot_plugin_mahjong_scoreboard/model/orm/_data_source/migrations/v1_to_v2.py @@ -0,0 +1,8 @@ +from sqlalchemy import text + +from ..src import data_source + + +async def migrate_v1_to_v2(): + async with data_source.engine.begin() as conn: + await conn.execute(text("ALTER TABLE game_records ADD point_scale integer NOT NULL DEFAULT 0;")) diff --git a/src/nonebot_plugin_mahjong_scoreboard/model/orm/_data_source/src.py b/src/nonebot_plugin_mahjong_scoreboard/model/orm/_data_source/src.py new file mode 100644 index 0000000..d9a8209 --- /dev/null +++ b/src/nonebot_plugin_mahjong_scoreboard/model/orm/_data_source/src.py @@ -0,0 +1,23 @@ +from urllib.parse import urlparse + +from nonebot import get_driver, require + +from nonebot_plugin_mahjong_scoreboard.config import conf + +require("nonebot_plugin_sqlalchemy") +from nonebot_plugin_sqlalchemy import DataSource + +data_source = DataSource(get_driver(), conf.mahjong_scoreboard_database_conn_url) + + +def _detect_dialect(): + url = urlparse(conf.mahjong_scoreboard_database_conn_url) + if '+' in url.scheme: + return url.scheme.split('+')[0] + else: + return url.scheme + + +dialect = _detect_dialect() + +__all__ = ("data_source", "dialect") diff --git a/src/nonebot_plugin_mahjong_scoreboard/model/orm/game.py b/src/nonebot_plugin_mahjong_scoreboard/model/orm/game.py index 55660d5..52eedc9 100644 --- a/src/nonebot_plugin_mahjong_scoreboard/model/orm/game.py +++ b/src/nonebot_plugin_mahjong_scoreboard/model/orm/game.py @@ -5,7 +5,7 @@ from sqlalchemy.orm import relationship from nonebot_plugin_mahjong_scoreboard.model.enums import Wind -from . import data_source +from ._data_source import data_source from ..enums import PlayerAndWind, GameState if TYPE_CHECKING: diff --git a/src/nonebot_plugin_mahjong_scoreboard/model/orm/group.py b/src/nonebot_plugin_mahjong_scoreboard/model/orm/group.py index f7788fb..042c4a2 100644 --- a/src/nonebot_plugin_mahjong_scoreboard/model/orm/group.py +++ b/src/nonebot_plugin_mahjong_scoreboard/model/orm/group.py @@ -3,7 +3,7 @@ from sqlalchemy import Column, Integer, BigInteger, ForeignKey, Index from sqlalchemy.orm import relationship -from . import data_source +from ._data_source import data_source if TYPE_CHECKING: from .season import SeasonOrm diff --git a/src/nonebot_plugin_mahjong_scoreboard/model/orm/season.py b/src/nonebot_plugin_mahjong_scoreboard/model/orm/season.py index d523a28..b648103 100644 --- a/src/nonebot_plugin_mahjong_scoreboard/model/orm/season.py +++ b/src/nonebot_plugin_mahjong_scoreboard/model/orm/season.py @@ -5,7 +5,7 @@ from sqlalchemy import Column, Integer, DateTime, String, Enum, ForeignKey, Boolean, Index from sqlalchemy.orm import relationship -from . import data_source +from ._data_source import data_source from .types.userdict import UserDict as SqlUserDict from ..enums import SeasonState, SeasonUserPointChangeType from ...utils.userdict import DictField diff --git a/src/nonebot_plugin_mahjong_scoreboard/model/orm/user.py b/src/nonebot_plugin_mahjong_scoreboard/model/orm/user.py index 8508382..81d1191 100644 --- a/src/nonebot_plugin_mahjong_scoreboard/model/orm/user.py +++ b/src/nonebot_plugin_mahjong_scoreboard/model/orm/user.py @@ -1,6 +1,6 @@ from sqlalchemy import Column, BigInteger, Integer, Index -from . import data_source +from ._data_source import data_source @data_source.registry.mapped