Skip to content

Commit

Permalink
including orm functions
Browse files Browse the repository at this point in the history
  • Loading branch information
TrueRou committed Sep 24, 2023
1 parent 6fc624a commit 4e8bce2
Show file tree
Hide file tree
Showing 4 changed files with 111 additions and 1 deletion.
1 change: 1 addition & 0 deletions app/api/init_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,7 @@ async def on_startup() -> None:
json_serialize=lambda x: orjson.dumps(x).decode(),
)
await app.state.services.database.connect()
await app.state.services.create_db_and_tables() # for sqlalchemy orm
await app.state.services.redis.initialize()

if app.state.services.datadog is not None:
Expand Down
91 changes: 91 additions & 0 deletions app/repositories/addition/orm_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
from typing import Generic, TypeVar
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.engine import ScalarResult
from sqlalchemy import select, delete, func

V = TypeVar("V")

async def add_model(session: AsyncSession, obj: V) -> V:
session.add(obj)
await session.commit()
await session.refresh(obj)
return obj


async def merge_model(session: AsyncSession, obj: V) -> V:
obj = await session.merge(obj)
await session.commit()
await session.refresh(obj)
return obj


async def delete_model(session: AsyncSession, ident, model):
target = await session.get(model, ident)
await session.delete(target)
await session.flush()
await session.commit() # Ensure deletion were operated


async def delete_models(session: AsyncSession, obj: Generic[V], condition):
sentence = delete(obj).where(condition)
await session.execute(sentence)


async def get_model(session: AsyncSession, ident, model: Generic[V]):
return await session.get(model, ident)


def _build_select_sentence(obj: Generic[V], condition=None, offset=-1, limit=-1, order_by=None):
return _enlarge_sentence(select(obj), condition, offset, limit, order_by)


def _enlarge_sentence(base, condition=None, offset=-1, limit=-1, order_by=None):
if condition is not None:
base = base.where(condition)
if order_by is not None:
base = base.order_by(order_by)
if offset != -1:
base = base.offset(offset)
if limit != -1:
base = base.limit(limit)
return base


async def select_model(session: AsyncSession, obj: Generic[V], condition=None, offset=-1, limit=-1, order_by=None) -> V:
sentence = _build_select_sentence(obj, condition, offset, limit, order_by)
model = await session.scalar(sentence)
return model


async def query_model(session: AsyncSession, sentence, condition=None, offset=-1, limit=-1, order_by=None):
sentence = _enlarge_sentence(sentence, condition, offset, limit, order_by)
model = await session.scalar(sentence)
return model


async def select_models(session: AsyncSession, obj: Generic[V], condition=None, offset=-1, limit=-1, order_by=None) -> ScalarResult[V]:
sentence = _build_select_sentence(obj, condition, offset, limit, order_by)
model = await session.scalars(sentence)
return model


async def query_models(session: AsyncSession, sentence, condition=None, offset=-1, limit=-1, order_by=None):
sentence = _enlarge_sentence(sentence, condition, offset, limit, order_by)
model = await session.scalars(sentence)
return model


async def select_models_count(session: AsyncSession, obj: Generic[V], condition=None, offset=-1, limit=-1, order_by=None) -> int:
sentence = _build_select_sentence(obj, condition, offset, limit, order_by)
sentence = sentence.with_only_columns(func.count(obj.id)).order_by(None)
model = await session.scalar(sentence)
return model


async def partial_update(session: AsyncSession, item: Generic[V], updates) -> V:
update_data = updates.dict(exclude_unset=True)
for key, value in update_data.items():
setattr(item, key, value)
await session.commit()
await session.refresh(item)
return item
1 change: 1 addition & 0 deletions app/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ def read_list(value: str) -> list[str]:
DB_PASS = os.environ["DB_PASS"]
DB_NAME = os.environ["DB_NAME"]
DB_DSN = f"mysql://{DB_USER}:{DB_PASS}@{DB_HOST}:{DB_PORT}/{DB_NAME}"
DB_DSN_ASYNC = f"mysql+aiomysql://{DB_USER}:{DB_PASS}@{DB_HOST}:{DB_PORT}/{DB_NAME}"

REDIS_HOST = os.environ["REDIS_HOST"]
REDIS_PORT = int(os.environ["REDIS_PORT"])
Expand Down
19 changes: 18 additions & 1 deletion app/state/services.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from __future__ import annotations

import asyncio
import contextlib
import ipaddress
import pickle
import re
Expand All @@ -9,11 +10,13 @@
from collections.abc import Mapping
from collections.abc import MutableMapping
from pathlib import Path
from typing import Optional
from typing import AsyncContextManager, Optional
from typing import TYPE_CHECKING
from typing import TypedDict

import databases
from sqlalchemy.ext.asyncio import create_async_engine, AsyncSession
from sqlalchemy.orm import sessionmaker, declarative_base
import datadog as datadog_module
import datadog.threadstats.base as datadog_client
import pymysql
Expand Down Expand Up @@ -42,6 +45,11 @@

http_client: aiohttp.ClientSession
database = databases.Database(app.settings.DB_DSN)

_sqla_engine = create_async_engine(app.settings.DB_DSN_ASYNC, future=True)
_async_session_maker = sessionmaker(_sqla_engine, class_=AsyncSession, expire_on_commit=False)
orm_base = declarative_base()

redis: aioredis.Redis = aioredis.from_url(app.settings.REDIS_DSN, decode_responses=True)

datadog: datadog_client.ThreadStats | None = None
Expand Down Expand Up @@ -107,6 +115,15 @@ class Geolocation(TypedDict):
}
# fmt: on

async def create_db_and_tables():
async with _sqla_engine.begin() as conn:
await conn.run_sync(orm_base.metadata.create_all)

@contextlib.asynccontextmanager
async def db_session() -> AsyncContextManager[AsyncSession]:
async with _async_session_maker() as session:
yield session
await session.commit()

class IPResolver:
def __init__(self) -> None:
Expand Down

0 comments on commit 4e8bce2

Please sign in to comment.