Skip to content

Commit

Permalink
添加数据库版本迁移
Browse files Browse the repository at this point in the history
(cherry picked from commit 7669290)
  • Loading branch information
ssttkkl committed Feb 24, 2023
1 parent 0812808 commit b3c593b
Show file tree
Hide file tree
Showing 10 changed files with 117 additions and 14 deletions.
11 changes: 1 addition & 10 deletions src/nonebot_plugin_mahjong_scoreboard/model/orm/__init__.py
Original file line number Diff line number Diff line change
@@ -1,10 +1 @@
from nonebot import get_driver, require

from nonebot_plugin_mahjong_scoreboard.config import conf

require("nonebot_plugin_sqlalchemy")
from nonebot_plugin_sqlalchemy import DataSource

data_source = DataSource(get_driver(), conf.mahjong_scoreboard_database_conn_url)

__all__ = ("data_source",)
from ._data_source import data_source
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
from . import metainfo
from . import migrations
from .src import data_source, dialect

__all__ = ("data_source", "dialect")
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
from sqlalchemy import Column, String, JSON, inspect, select
from sqlalchemy.ext.asyncio import AsyncSession

from .src import data_source

APP_DB_VERSION = 2


@data_source.registry.mapped
class MetaInfoOrm:
__tablename__ = 'metainfo'

key: str = Column(String(64), primary_key=True)
value = Column(JSON)


async def get_metainfo(key: str):
async with AsyncSession(data_source.engine) as session:
record = await session.get(MetaInfoOrm, key)
return record.value


async def set_metainfo(key: str, value: any):
async with AsyncSession(data_source.engine) as session:
record = await session.get(MetaInfoOrm, key)
if record is None:
record = MetaInfoOrm(key=key, value=value)
session.add(record)

record.value = value
await session.commit()


@data_source.on_engine_created
async def initialize_metainfo():
async with data_source.engine.begin() as conn:
await conn.run_sync(lambda conn: MetaInfoOrm.__table__.create(conn, checkfirst=True))

async with data_source.engine.begin() as conn:
async with AsyncSession(data_source.engine, expire_on_commit=False) as session:
# 判断是否初次建库
blank_database = not await conn.run_sync(lambda conn: inspect(conn).has_table("games"))
if blank_database:
result = MetaInfoOrm(key="db_version", value=APP_DB_VERSION)
session.add(result)
await session.commit()
else:
stmt = select(MetaInfoOrm).where(MetaInfoOrm.key == "db_version")
result = (await session.execute(stmt)).scalar_one_or_none()
if result is None:
result = MetaInfoOrm(key="db_version", value=1)
session.add(result)
await session.commit()

await conn.commit()
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
from nonebot import logger

from .v1_to_v2 import migrate_v1_to_v2
from ..metainfo import get_metainfo, APP_DB_VERSION, set_metainfo
from ..src import data_source

migrations = {
(1, 2): migrate_v1_to_v2
}


@data_source.on_ready
async def do_migrate():
cur_db_version = await get_metainfo('db_version')
while cur_db_version < APP_DB_VERSION:
mig = migrations[cur_db_version, cur_db_version + 1]
await mig()
await set_metainfo('db_version', cur_db_version + 1)
logger.success(f"migrate database from version {cur_db_version} to version {cur_db_version + 1}")

cur_db_version += 1
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
from sqlalchemy import text

from ..src import data_source


async def migrate_v1_to_v2():
async with data_source.engine.begin() as conn:
await conn.execute(text("ALTER TABLE game_records ADD point_scale integer NOT NULL DEFAULT 0;"))
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
from urllib.parse import urlparse

from nonebot import get_driver, require

from nonebot_plugin_mahjong_scoreboard.config import conf

require("nonebot_plugin_sqlalchemy")
from nonebot_plugin_sqlalchemy import DataSource

data_source = DataSource(get_driver(), conf.mahjong_scoreboard_database_conn_url)


def _detect_dialect():
url = urlparse(conf.mahjong_scoreboard_database_conn_url)
if '+' in url.scheme:
return url.scheme.split('+')[0]
else:
return url.scheme


dialect = _detect_dialect()

__all__ = ("data_source", "dialect")
2 changes: 1 addition & 1 deletion src/nonebot_plugin_mahjong_scoreboard/model/orm/game.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from sqlalchemy.orm import relationship

from nonebot_plugin_mahjong_scoreboard.model.enums import Wind
from . import data_source
from ._data_source import data_source
from ..enums import PlayerAndWind, GameState

if TYPE_CHECKING:
Expand Down
2 changes: 1 addition & 1 deletion src/nonebot_plugin_mahjong_scoreboard/model/orm/group.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from sqlalchemy import Column, Integer, BigInteger, ForeignKey, Index
from sqlalchemy.orm import relationship

from . import data_source
from ._data_source import data_source

if TYPE_CHECKING:
from .season import SeasonOrm
Expand Down
2 changes: 1 addition & 1 deletion src/nonebot_plugin_mahjong_scoreboard/model/orm/season.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from sqlalchemy import Column, Integer, DateTime, String, Enum, ForeignKey, Boolean, Index
from sqlalchemy.orm import relationship

from . import data_source
from ._data_source import data_source
from .types.userdict import UserDict as SqlUserDict
from ..enums import SeasonState, SeasonUserPointChangeType
from ...utils.userdict import DictField
Expand Down
2 changes: 1 addition & 1 deletion src/nonebot_plugin_mahjong_scoreboard/model/orm/user.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from sqlalchemy import Column, BigInteger, Integer, Index

from . import data_source
from ._data_source import data_source


@data_source.registry.mapped
Expand Down

0 comments on commit b3c593b

Please sign in to comment.