-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
5 changed files
with
92 additions
and
130 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,38 +1,103 @@ | ||
from loguru import logger | ||
from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine | ||
from contextlib import asynccontextmanager | ||
|
||
from sqlalchemy.ext.asyncio import create_async_engine, AsyncSession, async_sessionmaker | ||
from sqlalchemy.orm import sessionmaker | ||
from sqlmodel import SQLModel, create_engine | ||
|
||
from sqlmodel import SQLModel | ||
from typing import TypeVar | ||
|
||
from icon_stats.config import config | ||
|
||
SQLALCHEMY_DATABASE_URL_STUB = "://{user}:{password}@{server}:{port}/{db}".format( | ||
user=config.POSTGRES_USER, | ||
password=config.POSTGRES_PASSWORD, | ||
server=config.POSTGRES_SERVER, | ||
port=config.POSTGRES_PORT, | ||
db=config.POSTGRES_DATABASE, | ||
|
||
def create_conn_str( | ||
user: str, | ||
password: str, | ||
server: str, | ||
port: str, | ||
database: str, | ||
prefix: str = "postgresql+asyncpg", | ||
**kwargs, | ||
) -> str: | ||
return f"{prefix}://{user}:{password}@{server}:{port}/{database}" | ||
|
||
|
||
# ASYNC_SQLALCHEMY_DATABASE_URL = create_conn_str(**config.db.stats.__dict__) | ||
ASYNC_CONNECTION_STRING = create_conn_str( | ||
user="postgres", | ||
password="changeme", | ||
server="localhost", | ||
database="postgres", | ||
port="5432", | ||
) | ||
|
||
ASYNC_SQLALCHEMY_DATABASE_URL = "postgresql+asyncpg" + SQLALCHEMY_DATABASE_URL_STUB | ||
SQLALCHEMY_DATABASE_URL = "postgresql+psycopg2" + SQLALCHEMY_DATABASE_URL_STUB | ||
def create_db_connection_strings() -> dict[str, str]: | ||
connection_strings = {} | ||
|
||
for db, c in config.db.__dict__.items(): | ||
connection_strings.update({db: create_conn_str(**c.__dict__)}) | ||
|
||
return connection_strings | ||
|
||
|
||
# engines = {} | ||
# | ||
# engines = { | ||
# db: create_async_engine( | ||
# url, | ||
# connect_args={"options": f"-c search_path={config.}"}, | ||
# echo=True, | ||
# ) | ||
# for db, url in create_db_connection_strings().items() | ||
# } | ||
# | ||
# # A dict to hold sessions for each of the DBs being used | ||
# session_factories = { | ||
# db: async_sessionmaker(bind=engine, class_=AsyncSession, expire_on_commit=False) | ||
# for db, engine in engines.items() | ||
# } | ||
|
||
def create_session_factories() -> dict[str, async_sessionmaker]: | ||
output = {} | ||
|
||
for db_name, db_config in config.db.__dict__.items(): | ||
connection_string = create_conn_str(**db_config.__dict__) | ||
engine = create_async_engine( | ||
connection_string, | ||
# connect_args={"options": f"-c search_path={db_config.schema_}"}, | ||
echo=True, | ||
) | ||
session_factory = async_sessionmaker( | ||
bind=engine, | ||
class_=AsyncSession, | ||
expire_on_commit=False, | ||
) | ||
output.update({db_name: session_factory}) | ||
|
||
return output | ||
|
||
logger.info(f"Connecting to server: {config.POSTGRES_SERVER} and {config.POSTGRES_DATABASE}") | ||
|
||
async_engine = create_async_engine(ASYNC_SQLALCHEMY_DATABASE_URL, echo=True, future=True) | ||
session_factories = create_session_factories() | ||
|
||
@asynccontextmanager | ||
async def get_session(db_name: str): | ||
if db_name not in session_factories: | ||
raise ValueError( | ||
f"No session factory registered for database key: {db_name}") | ||
|
||
# Run onetime if we want to init with a prebuilt table of attributes | ||
async def init_db(): | ||
async with async_engine.begin() as conn: | ||
# await conn.run_sync(SQLModel.metadata.drop_all) | ||
await conn.run_sync(SQLModel.metadata.create_all) | ||
async with session_factories[db_name]() as session: | ||
try: | ||
yield session | ||
await session.commit() | ||
except Exception as e: | ||
await session.rollback() | ||
raise e | ||
|
||
|
||
async def get_session() -> AsyncSession: | ||
async_session = sessionmaker(async_engine, class_=AsyncSession, expire_on_commit=False) | ||
async with async_session() as session: | ||
yield session | ||
# Generic sqlmodel table | ||
TDbModel = TypeVar("TDbModel", bound=SQLModel) | ||
|
||
|
||
engine = create_engine(SQLALCHEMY_DATABASE_URL) | ||
session_factory = sessionmaker(bind=engine) | ||
async def upsert_model(db_name: str, model: TDbModel): | ||
async with get_session(db_name=db_name) as session: | ||
session.add(model) | ||
await session.commit() |
This file was deleted.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters