diff --git a/.env b/.env
new file mode 100644
index 0000000..f9ea633
--- /dev/null
+++ b/.env
@@ -0,0 +1,3 @@
+WEB3_HTTPS_PROVIDER_URI=https://eth-mainnet.g.alchemy.com/v2/<your-alchemy-api-key>
+ALCHEMY_API_KEY=xxx
+QUART_APP=quart_sqlalchemy.sim.main:app
diff --git a/docs/Simulation.md b/docs/Simulation.md
new file mode 100644
index 0000000..656d8df
--- /dev/null
+++ b/docs/Simulation.md
@@ -0,0 +1,159 @@
+# Simulation Docs
+
+# initialize database
+```shell
+quart db create
+```
+```
+Initialized database schema for
+```
+
+# add first client to the database (Using CLI)
+```shell
+quart auth add-client
+```
+```
+Created client 2VolejRejNmG with public_api_key: 5f794cf72d0cef2dd008be2c0b7a632b
+```
+
+Use the `public_api_key` returned for the value of the `X-Public-API-Key` header when making API requests.
+
+
+# Create new auth_user via api
+```shell
+curl -X POST localhost:8081/api/auth_user/ \
+    -H 'X-Public-API-Key: 5f794cf72d0cef2dd008be2c0b7a632b' \
+    -H 'Content-Type: application/json' \
+    --data '{"email": "joe2@joe.com"}'
+```
+```json
+{
+    "data": {
+        "auth_user": {
+            "client_id": "2VolejRejNmG",
+            "current_session_token": "69ee9af5b9296a09f90be5b71c1dda38",
+            "date_verified": 1681344793,
+            "delegated_identity_pool_id": null,
+            "delegated_user_id": null,
+            "email": "joe2@joe.com",
+            "global_auth_user_id": null,
+            "id": "GWpmbk5ezJn4",
+            "is_admin": false,
+            "linked_primary_auth_user_id": null,
+            "phone_number": null,
+            "provenance": null,
+            "user_type": 2
+        }
+    },
+    "error_code": "",
+    "message": "",
+    "status": ""
+}
+```
+
+Use the `current_session_token` returned for the value of the `Authorization: Bearer {token}` header when making API requests requiring a user.
+ +# get AuthUser corresponding to provided bearer session token +```shell +curl -X GET localhost:8081/api/auth_user/ \ + -H 'X-Public-API-Key: 5f794cf72d0cef2dd008be2c0b7a632b' \ + -H 'Authorization: Bearer 69ee9af5b9296a09f90be5b71c1dda38' \ + -H 'Content-Type: application/json' +``` +```json +{ + "data": { + "client_id": "2VolejRejNmG", + "current_session_token": "69ee9af5b9296a09f90be5b71c1dda38", + "date_verified": 1681344793, + "delegated_identity_pool_id": null, + "delegated_user_id": null, + "email": "joe2@joe.com", + "global_auth_user_id": null, + "id": "GWpmbk5ezJn4", + "is_admin": false, + "linked_primary_auth_user_id": null, + "phone_number": null, + "provenance": null, + "user_type": 2 + }, + "error_code": "", + "message": "", + "status": "" +} +``` + + +# AuthWallet Sync +```shell +curl -X POST localhost:8081/api/auth_wallet/sync \ + -H 'X-Public-API-Key: 5f794cf72d0cef2dd008be2c0b7a632b' \ + -H 'Authorization: Bearer 69ee9af5b9296a09f90be5b71c1dda38' \ + -H 'Content-Type: application/json' \ + --data '{"public_address": "xxx", "encrypted_private_address": "xxx", "wallet_type": "ETH"}' +``` +```json +{ + "data": { + "auth_user_id": "GWpmbk5ezJn4", + "encrypted_private_address": "xxx", + "public_address": "xxx", + "wallet_id": "GWpmbk5ezJn4", + "wallet_type": "ETH" + }, + "error_code": "", + "message": "", + "status": "" +} +``` + +# get magic client corresponding to provided public api key +```shell +curl -X GET localhost:8081/api/magic_client/ \ + -H 'X-Public-API-Key: 5f794cf72d0cef2dd008be2c0b7a632b' \ + -H 'Content-Type: application/json' +``` +```json +{ + "data": { + "app_name": "My App", + "connect_interop": null, + "global_audience_enabled": false, + "id": "2VolejRejNmG", + "is_signing_modal_enabled": false, + "public_api_key": "5f794cf72d0cef2dd008be2c0b7a632b", + "rate_limit_tier": null, + "secret_api_key": "c6ecbced505b35505751c862ed0fb10ffb623d24095019433e0d4d94e240e508" + }, + "error_code": "", + "message": "", + "status": "" +} +``` + +# 
Create new magic client +```shell +curl -X POST localhost:8081/api/magic_client/ \ + -H 'X-Public-API-Key: 5f794cf72d0cef2dd008be2c0b7a632b' \ + -H 'Content-Type: application/json' \ + --data '{"app_name": "New App"}' +``` +```json +{ + "data": { + "magic_client": { + "app_name": "New App", + "connect_interop": null, + "global_audience_enabled": false, + "id": "GWpmbk5ezJn4", + "is_signing_modal_enabled": false, + "public_api_key": "fb7e0466e2e09387b93af7da49bb1386", + "rate_limit_tier": null, + "secret_api_key": "2ac56a6068d0d4b2ce911ba08401c7bf4acdb03db957550c260bd317c6c49a76" + } + }, + "error_code": "", + "message": "", + "status": "" +} +``` \ No newline at end of file diff --git a/docs/usage.md b/docs/usage.md new file mode 100644 index 0000000..f0bfdbe --- /dev/null +++ b/docs/usage.md @@ -0,0 +1,300 @@ +# API + +## SQLAlchemy +### `quart_sqlalchemy.sqla.SQLAlchemy` + +### Conventions +This manager class keeps things very simple by using a few configuration conventions: + +* Configuration has been simplified down to base_class and binds. +* Everything related to ORM mapping, DeclarativeBase, registry, MetaData, etc should be configured by passing the a custom DeclarativeBase class as the base_class configuration parameter. +* Everything related to engine/session configuration should be configured by passing a dictionary mapping string names to BindConfigs as the `binds` configuration parameter. 
+* the bind named `default` is the canonical bind, and to be used unless something more specific has been requested + +### Configuration +BindConfig can be as simple as a dictionary containing a url key like so: +```python +bind_config = { + "default": {"url": "sqlite://"} +} +``` + +But most use cases will require more than just a connection url, and divide core/engine configuration from orm/session configuration which looks more like this: +```python +bind_config = { + "default": { + "engine": { + "url": "sqlite://" + }, + "session": { + "expire_on_commit": False + } + } +} +``` + +It helps to think of the bind configuration as being the options dictionary used to build the main core and orm factory objects. +* For SQLAlchemy core, the configuration under the key `engine` will be used by `sa.engine_from_config` to build the `sa.Engine` object which acts as a factory for `sa.Connection` objects. + ```python + engine = sa.engine_from_config(config.engine, prefix="") + ``` +* For SQLAlchemy orm, the configuration under the key `session` will be used to build the `sa.orm.sessionmaker` session factory which acts as a factory for `sa.orm.Session` objects. + ```python + session_factory = sa.orm.sessionmaker(bind=engine, **config.session) + ``` + +#### Usage Examples +SQLAlchemyConfig is to be passed to SQLAlchemy or QuartSQLAlchemy as the first parameter when initializing. + +```python +db = SQLAlchemy( + SQLAlchemyConfig( + binds=dict( + default=dict( + url="sqlite://" + ) + ) + ) +) +``` + +When nothing is provided to SQLAlchemyConfig directly, it is instantiated with the following defaults + +```python +db = SQLAlchemy(SQLAlchemyConfig()) +``` + +For `QuartSQLAlchemy` configuration can also be provided via Quart configuration. 
+```python +from quart_sqlalchemy.framework import QuartSQLAlchemy + +app = Quart(__name__) +app.config.from_mapping( + { + "SQLALCHEMY_BINDS": { + "default": { + "engine": {"url": "sqlite:///file:mem.db?mode=memory&cache=shared&uri=true"}, + "session": {"expire_on_commit": False}, + } + }, + "SQLALCHEMY_BASE_CLASS": Base, + } +) +db = QuartSQLAlchemy(app=app) +``` + + + + +A typical configuration containing engine and session config both: +```python +config = SQLAlchemyConfig( + binds=dict( + default=dict( + engine=dict( + url="sqlite://" + ), + session=dict( + expire_on_commit=False + ) + ) + ) +) +``` + +Async first configuration +```python +config = SQLAlchemyConfig( + binds=dict( + default=dict( + engine=dict( + url="sqlite+aiosqlite:///file:mem.db?mode=memory&cache=shared&uri=true" + ), + session=dict( + expire_on_commit=False + ) + ) + ) +) +``` + +More complex configuration having two additional binds based on default, one for a read-replica and the second having an async driver + +```python +config = { + "SQLALCHEMY_BINDS": { + "default": { + "engine": {"url": "sqlite:///file:mem.db?mode=memory&cache=shared&uri=true"}, + "session": {"expire_on_commit": False}, + }, + "read-replica": { + "engine": {"url": "sqlite:///file:mem.db?mode=memory&cache=shared&uri=true"}, + "session": {"expire_on_commit": False}, + "read_only": True, + }, + "async": { + "engine": {"url": "sqlite+aiosqlite:///file:mem.db?mode=memory&cache=shared&uri=true"}, + "session": {"expire_on_commit": False}, + }, + }, + "SQLALCHEMY_BASE_CLASS": Base, +} +``` + + + Once instantiated, operations targetting all of the binds, aka metadata, like + `metadata.create_all` should be called from this class. Operations specific to a bind + should be called from that bind. This class has a few ways to get a specific bind. + + * To get a Bind, you can call `.get_bind(name)` on this class. The default bind can be + referenced at `.bind`. 
+ + * To define an ORM model using the Base class attached to this class, simply inherit + from `.Base` + + db = SQLAlchemy(SQLAlchemyConfig()) + + class User(db.Base): + id: Mapped[int] = sa.orm.mapped_column(primary_key=True, autoincrement=True) + name: Mapped[str] = sa.orm.mapped_column(default="Joe") + + db.create_all() + + * You can also decouple Base from SQLAlchemy with some dependency inversion: + from quart_sqlalchemy.model.mixins import DynamicArgsMixin, ReprMixin, TableNameMixin + + class Base(DynamicArgsMixin, ReprMixin, TableNameMixin): + __abstract__ = True + + + class User(db.Base): + id: Mapped[int] = sa.orm.mapped_column(primary_key=True, autoincrement=True) + name: Mapped[str] = sa.orm.mapped_column(default="Joe") + + db = SQLAlchemy(SQLAlchemyConfig(bind_class=Base)) + + db.create_all() + + + Declarative Mapping using registry based decorator: + + db = SQLAlchemy(SQLAlchemyConfig()) + + @db.registry.mapped + class User(db.Base): + id: Mapped[int] = sa.orm.mapped_column(primary_key=True, autoincrement=True) + name: Mapped[str] = sa.orm.mapped_column(default="Joe") + + db.create_all() + + + Declarative with Imperative Table (Hybrid Declarative): + + class User(db.Base): + __table__ = sa.Table( + "user", + db.metadata, + sa.Column("id", sa.Integer, primary_key=True, autoincrement=True), + sa.Column("name", sa.String, default="Joe"), + ) + + + Declarative using reflection to automatically build the table object: + + class User(db.Base): + __table__ = sa.Table( + "user", + db.metadata, + autoload_with=db.bind.engine, + ) + + + Declarative Dataclass Mapping: + + from quart_sqlalchemy.model import Base as Base_ + + class Base(sa.orm.MappedAsDataclass, Base_): + pass + + db = SQLAlchemy(SQLAlchemyConfig(base_class=Base)) + + class User(db.Base): + __tablename__ = "user" + + id: Mapped[int] = sa.orm.mapped_column(primary_key=True, autoincrement=True) + name: Mapped[str] = sa.orm.mapped_column(default="Joe") + + db.create_all() + + + Declarative Dataclass 
Mapping (using decorator): + + db = SQLAlchemy(SQLAlchemyConfig(base_class=Base)) + + @db.registry.mapped_as_dataclass + class User: + __tablename__ = "user" + + id: Mapped[int] = sa.orm.mapped_column(primary_key=True, autoincrement=True) + name: Mapped[str] = sa.orm.mapped_column(default="Joe") + + db.create_all() + + + Alternate Dataclass Provider Pattern: + + from pydantic.dataclasses import dataclass + from quart_sqlalchemy.model import Base as Base_ + + class Base(sa.orm.MappedAsDataclass, Base_, dataclass_callable=dataclass): + pass + + db = SQLAlchemy(SQLAlchemyConfig(base_class=Base)) + + class User(db.Base): + __tablename__ = "user" + + id: Mapped[int] = sa.orm.mapped_column(primary_key=True, autoincrement=True) + name: Mapped[str] = sa.orm.mapped_column(default="Joe") + + db.create_all() + + Imperative style Mapping + + db = SQLAlchemy(SQLAlchemyConfig(base_class=Base)) + + user_table = sa.Table( + "user", + db.metadata, + sa.Column("id", sa.Integer, primary_key=True, autoincrement=True), + sa.Column("name", sa.String, default="Joe"), + ) + + post_table = sa.Table( + "post", + db.metadata, + sa.Column("id", sa.Integer, primary_key=True, autoincrement=True), + sa.Column("title", sa.String, default="My post"), + sa.Column("user_id", sa.ForeignKey("user.id"), nullable=False), + ) + + class User: + pass + + class Post: + pass + + db.registry.map_imperatively( + User, + user_table, + properties={ + "posts": sa.orm.relationship(Post, back_populates="user") + } + ) + db.registry.map_imperatively( + Post, + post_table, + properties={ + "user": sa.orm.relationship(User, back_populates="posts", uselist=False) + } + ) \ No newline at end of file diff --git a/examples/decorators/provide_session.py b/examples/decorators/provide_session.py new file mode 100644 index 0000000..f67f2b0 --- /dev/null +++ b/examples/decorators/provide_session.py @@ -0,0 +1,57 @@ +import inspect +import typing as t +from contextlib import contextmanager +from functools import wraps + + +RT = 
t.TypeVar("RT") + + +@contextmanager +def create_session(bind): + """Contextmanager that will create and teardown a session.""" + session = bind.Session() + try: + yield session + session.commit() + except Exception: + session.rollback() + raise + finally: + session.close() + + +def provide_session(bind_name: str = "default"): + """ + Function decorator that provides a session if it isn't provided. + If you want to reuse a session or run the function as part of a + database transaction, you pass it to the function, if not this wrapper + will create one and close it for you. + """ + + def decorator(func: t.Callable[..., RT]) -> t.Callable[..., RT]: + from quart_sqlalchemy import Bind + + func_params = inspect.signature(func).parameters + try: + # func_params is an ordered dict -- this is the "recommended" way of getting the position + session_args_idx = tuple(func_params).index("session") + except ValueError: + raise ValueError(f"Function {func.__qualname__} has no `session` argument") from None + + # We don't need this anymore -- ensure we don't keep a reference to it by mistake + del func_params + + @wraps(func) + def wrapper(*args, **kwargs) -> RT: + if "session" in kwargs or session_args_idx < len(args): + return func(*args, **kwargs) + else: + bind = Bind.get_instance(bind_name) + + with create_session(bind) as session: + return func(*args, session=session, **kwargs) + + return wrapper + + return decorator diff --git a/examples/repository/base.py b/examples/repository/base.py index ec70e39..3ae0d2e 100644 --- a/examples/repository/base.py +++ b/examples/repository/base.py @@ -1,5 +1,6 @@ from __future__ import annotations +import operator import typing as t from abc import ABCMeta from abc import abstractmethod @@ -14,42 +15,49 @@ from quart_sqlalchemy.types import ColumnExpr from quart_sqlalchemy.types import EntityIdT from quart_sqlalchemy.types import EntityT +from quart_sqlalchemy.types import Operator from quart_sqlalchemy.types import ORMOption from 
quart_sqlalchemy.types import Selectable +from quart_sqlalchemy.types import SessionT sa = sqlalchemy -class AbstractRepository(t.Generic[EntityT, EntityIdT], metaclass=ABCMeta): +class AbstractRepository(t.Generic[EntityT, EntityIdT, SessionT], metaclass=ABCMeta): """A repository interface.""" - identity: t.Type[EntityIdT] + # entity: t.Type[EntityT] - # def __init__(self, model: t.Type[EntityT]): - # self.model = model + # def __init__(self, entity: t.Type[EntityT]): + # self.entity = entity @property - def model(self) -> EntityT: + def entity(self) -> EntityT: return self.__orig_class__.__args__[0] @abstractmethod - def insert(self, values: t.Dict[str, t.Any]) -> EntityT: + def insert(self, session: SessionT, values: t.Dict[str, t.Any]) -> EntityT: """Add `values` to the collection.""" @abstractmethod - def update(self, id_: EntityIdT, values: t.Dict[str, t.Any]) -> EntityT: + def update(self, session: SessionT, id_: EntityIdT, values: t.Dict[str, t.Any]) -> EntityT: """Update model with model_id using values.""" @abstractmethod def merge( - self, id_: EntityIdT, values: t.Dict[str, t.Any], for_update: bool = False + self, + session: SessionT, + id_: EntityIdT, + values: t.Dict[str, t.Any], + for_update: bool = False, ) -> EntityT: """Merge model with model_id using values.""" @abstractmethod def get( self, + session: SessionT, id_: EntityIdT, options: t.Sequence[ORMOption] = (), execution_options: t.Optional[t.Dict[str, t.Any]] = None, @@ -58,9 +66,28 @@ def get( ) -> t.Optional[EntityT]: """Get model with model_id.""" + @abstractmethod + def get_by_field( + self, + session: SessionT, + field: t.Union[ColumnExpr, str], + value: t.Any, + op: Operator = operator.eq, + order_by: t.Sequence[t.Union[ColumnExpr, str]] = (), + options: t.Sequence[ORMOption] = (), + execution_options: t.Optional[t.Dict[str, t.Any]] = None, + offset: t.Optional[int] = None, + limit: t.Optional[int] = None, + distinct: bool = False, + for_update: bool = False, + include_inactive: bool = 
False, + ) -> sa.ScalarResult[EntityT]: + """Select models where field is equal to value.""" + @abstractmethod def select( self, + session: SessionT, selectables: t.Sequence[Selectable] = (), conditions: t.Sequence[ColumnExpr] = (), group_by: t.Sequence[t.Union[ColumnExpr, str]] = (), @@ -77,12 +104,13 @@ def select( """Select models matching conditions.""" @abstractmethod - def delete(self, id_: EntityIdT) -> None: + def delete(self, session: SessionT, id_: EntityIdT) -> None: """Delete model with id_.""" @abstractmethod def exists( self, + session: SessionT, conditions: t.Sequence[ColumnExpr] = (), for_update: bool = False, include_inactive: bool = False, @@ -90,27 +118,31 @@ def exists( """Return the existence of an object matching conditions.""" @abstractmethod - def deactivate(self, id_: EntityIdT) -> EntityT: + def deactivate(self, session: SessionT, id_: EntityIdT) -> EntityT: """Soft-Delete model with id_.""" @abstractmethod - def reactivate(self, id_: EntityIdT) -> EntityT: + def reactivate(self, session: SessionT, id_: EntityIdT) -> EntityT: """Soft-Delete model with id_.""" -class AbstractBulkRepository(t.Generic[EntityT, EntityIdT], metaclass=ABCMeta): +class AbstractBulkRepository(t.Generic[EntityT, EntityIdT, SessionT], metaclass=ABCMeta): """A repository interface for bulk operations. Note: this interface circumvents ORM internals, breaking commonly expected behavior in order to gain performance benefits. Only use this class whenever absolutely necessary. 
""" - model: t.Type[EntityT] builder: StatementBuilder + @property + def entity(self) -> EntityT: + return self.__orig_class__.__args__[0] + @abstractmethod def bulk_insert( self, + session: SessionT, values: t.Sequence[t.Dict[str, t.Any]] = (), execution_options: t.Optional[t.Dict[str, t.Any]] = None, ) -> sa.Result[t.Any]: @@ -119,6 +151,7 @@ def bulk_insert( @abstractmethod def bulk_update( self, + session: SessionT, conditions: t.Sequence[ColumnExpr] = (), values: t.Optional[t.Dict[str, t.Any]] = None, execution_options: t.Optional[t.Dict[str, t.Any]] = None, @@ -128,6 +161,7 @@ def bulk_update( @abstractmethod def bulk_delete( self, + session: SessionT, conditions: t.Sequence[ColumnExpr] = (), execution_options: t.Optional[t.Dict[str, t.Any]] = None, ) -> sa.Result[t.Any]: diff --git a/examples/repository/sqla.py b/examples/repository/sqla.py index 417ddaf..5f77dae 100644 --- a/examples/repository/sqla.py +++ b/examples/repository/sqla.py @@ -1,5 +1,6 @@ from __future__ import annotations +import operator import typing as t import sqlalchemy @@ -15,6 +16,7 @@ from quart_sqlalchemy.types import ColumnExpr from quart_sqlalchemy.types import EntityIdT from quart_sqlalchemy.types import EntityT +from quart_sqlalchemy.types import Operator from quart_sqlalchemy.types import ORMOption from quart_sqlalchemy.types import Selectable from quart_sqlalchemy.types import SessionT @@ -25,8 +27,8 @@ class SQLAlchemyRepository( TableMetadataMixin, - AbstractRepository[EntityT, EntityIdT], - t.Generic[EntityT, EntityIdT], + AbstractRepository[EntityT, EntityIdT, SessionT], + t.Generic[EntityT, EntityIdT, SessionT], ): """A repository that uses SQLAlchemy to persist data. 
@@ -53,7 +55,7 @@ class SQLAlchemyRepository(
     session: sa.orm.Session
     builder: StatementBuilder
 
-    def __init__(self, session: sa.orm.Session, **kwargs):
+    def __init__(self, session: sa.orm.Session, **kwargs):
         super().__init__(**kwargs)
         self.session = session
         self.builder = StatementBuilder(None)
@@ -125,6 +127,46 @@ def get(
         return self.session.scalars(statement, execution_options=execution_options).one_or_none()
 
+    def get_by_field(
+        self,
+        field: t.Union[ColumnExpr, str],
+        value: t.Any,
+        op: Operator = operator.eq,
+        order_by: t.Sequence[t.Union[ColumnExpr, str]] = (),
+        options: t.Sequence[ORMOption] = (),
+        execution_options: t.Optional[t.Dict[str, t.Any]] = None,
+        offset: t.Optional[int] = None,
+        limit: t.Optional[int] = None,
+        distinct: bool = False,
+        for_update: bool = False,
+        include_inactive: bool = False,
+    ) -> sa.ScalarResult[EntityT]:
+        """Select models where field is equal to value."""
+        selectables = (self.model,)  # type: ignore
+
+        execution_options = execution_options or {}
+        if include_inactive:
+            execution_options.setdefault("include_inactive", include_inactive)
+
+        if isinstance(field, str):
+            field = getattr(self.model, field)
+
+        conditions = [t.cast(ColumnExpr, op(field, value))]
+
+        statement = self.builder.complex_select(
+            selectables,
+            conditions=conditions,
+            order_by=order_by,
+            options=options,
+            execution_options=execution_options,
+            offset=offset,
+            limit=limit,
+            distinct=distinct,
+            for_update=for_update,
+        )
+
+        return self.session.scalars(statement)
+
     def select(
         self,
         selectables: t.Sequence[Selectable] = (),
         conditions: t.Sequence[ColumnExpr] = (),
diff --git a/examples/usrsrv/component/__init__.py b/examples/usrsrv/component/__init__.py
new file mode 100644
index 0000000..21b27e9
--- /dev/null
+++ b/examples/usrsrv/component/__init__.py
@@ -0,0 +1,23 @@
+from . import command as commands
+from . import event as events
+from . import exception as exceptions
+from .entity import EntityID
+from .app import handler
+from .service import CommandHandler
+from .service import Listener
+
+
+handle = handler.handle
+register = handler.register
+unregister = handler.unregister
+
+__all__ = [
+    "commands",
+    "events",
+    "exceptions",
+    "EntityID",
+    "CommandHandler",
+    "handle",
+    "register",
+    "unregister",
+]
diff --git a/examples/usrsrv/component/app.py b/examples/usrsrv/component/app.py
new file mode 100644
index 0000000..b18baee
--- /dev/null
+++ b/examples/usrsrv/component/app.py
@@ -0,0 +1,10 @@
+from sqlalchemy import create_engine
+from sqlalchemy.orm import sessionmaker
+
+from .repository import ORMRepository
+from .service import CommandHandler
+
+
+some_engine = create_engine("sqlite:///")
+Session = sessionmaker(bind=some_engine)
+handler = CommandHandler(ORMRepository(Session()))
diff --git a/examples/usrsrv/component/command.py b/examples/usrsrv/component/command.py
new file mode 100644
index 0000000..47f59b0
--- /dev/null
+++ b/examples/usrsrv/component/command.py
@@ -0,0 +1,43 @@
+"""
+Commands
+========
+A command is always DTO and as specific, as it can be from a domain perspective. I aim to create
+separate classes for commands so I can just dispatch handlers by command class.
+ +```python +@dataclass +class Create(Command): + command_id: CommandID = field(default_factory=uuid1) + timestamp: datetime = field(default_factory=datetime.utcnow) +``` +""" + +from abc import ABC +from dataclasses import dataclass, field +from datetime import datetime +from typing import Text +from uuid import UUID, uuid1 + +from .entity import EntityID + +CommandID = UUID + + +class Command(ABC): + entity_id: EntityID + command_id: CommandID + timestamp: datetime + + +@dataclass +class Create(Command): + command_id: CommandID = field(default_factory=uuid1) + timestamp: datetime = field(default_factory=datetime.utcnow) + + +@dataclass +class UpdateValue(Command): + entity_id: EntityID + value: Text + command_id: CommandID = field(default_factory=uuid1) + timestamp: datetime = field(default_factory=datetime.utcnow) diff --git a/examples/usrsrv/component/entity.py b/examples/usrsrv/component/entity.py new file mode 100644 index 0000000..a910ad2 --- /dev/null +++ b/examples/usrsrv/component/entity.py @@ -0,0 +1,39 @@ +from typing import NewType +from typing import Optional +from typing import Text +from uuid import UUID +from uuid import uuid1 + + +EntityID = NewType("EntityID", UUID) + + +class EntityDTO: + id: EntityID + value: Optional[Text] + + +class Entity: + id: EntityID + dto: EntityDTO + + class Event: + pass + + class Updated(Event): + pass + + def __init__(self, dto: EntityDTO) -> None: + self.id = dto.id + self.dto = dto + + @classmethod + def create(cls) -> "Entity": + dto = EntityDTO() + dto.id = EntityID(uuid1()) + dto.value = None + return Entity(dto) + + def update(self, value: Text) -> Updated: + self.dto.value = value + return self.Updated() diff --git a/examples/usrsrv/component/event.py b/examples/usrsrv/component/event.py new file mode 100644 index 0000000..4dd2895 --- /dev/null +++ b/examples/usrsrv/component/event.py @@ -0,0 +1,45 @@ +from dataclasses import dataclass +from dataclasses import field +from datetime import datetime +from 
functools import singledispatch +from uuid import UUID +from uuid import uuid1 + +from .command import Command +from .command import CommandID +from .entity import Entity +from .entity import EntityID + + +EventID = UUID + + +class Event: + command_id: CommandID + event_id: EventID = field(default_factory=uuid1) + timestamp: datetime = field(default_factory=datetime.utcnow) + + +@dataclass +class Created(Event): + command_id: CommandID + uow_id: EntityID + event_id: EventID = field(default_factory=uuid1) + timestamp: datetime = field(default_factory=datetime.utcnow) + + +@dataclass +class Updated(Event): + command_id: CommandID + event_id: EventID = field(default_factory=uuid1) + timestamp: datetime = field(default_factory=datetime.utcnow) + + +@singledispatch +def app_event(event: Entity.Event, command: Command) -> Event: + raise NotImplementedError + + +@app_event.register(Entity.Updated) +def _(event: Entity.Updated, command: Command) -> Updated: + return Updated(command.command_id) diff --git a/examples/usrsrv/component/exception.py b/examples/usrsrv/component/exception.py new file mode 100644 index 0000000..f88096f --- /dev/null +++ b/examples/usrsrv/component/exception.py @@ -0,0 +1,2 @@ +class NotFound(Exception): + pass diff --git a/examples/usrsrv/component/migrations/2020-04-15_ddd_component_unitofwork_18fd763a02a4.py b/examples/usrsrv/component/migrations/2020-04-15_ddd_component_unitofwork_18fd763a02a4.py new file mode 100644 index 0000000..e69de29 diff --git a/examples/usrsrv/component/repository.py b/examples/usrsrv/component/repository.py new file mode 100644 index 0000000..44dffcd --- /dev/null +++ b/examples/usrsrv/component/repository.py @@ -0,0 +1,63 @@ +from sqlalchemy import Column +from sqlalchemy import Integer +from sqlalchemy import MetaData +from sqlalchemy import select +from sqlalchemy import String +from sqlalchemy import Table +from sqlalchemy.orm import registry +from sqlalchemy.orm import Session + +from . 
import EntityID +from .entity import Entity +from .entity import EntityDTO +from .exception import NotFound +from .service import Repository + + +metadata = MetaData() +mapper_registry = registry(metadata=metadata) + + +entities_table = Table( + "entities", + metadata, + Column("id", Integer, primary_key=True, autoincrement=True), + Column("uuid", String, unique=True, index=True), + Column("value", String, nullable=True), +) + +# EntityMapper = mapper( +# EntityDTO, +# entities_table, +# properties={ +# "id": entities_table.c.uuid, +# "value": entities_table.c.value, +# }, +# column_prefix="_db_column_", +# ) + +EntityMapper = mapper_registry.map_imperatively( + EntityDTO, + entities_table, + properties={ + "id": entities_table.c.uuid, + "value": entities_table.c.value, + }, + column_prefix="_db_column_", +) + + +class ORMRepository(Repository): + def __init__(self, session: Session): + self._session = session + self._query = select(EntityMapper) + + def get(self, entity_id: EntityID) -> Entity: + dto = self._session.scalars(self._query.filter_by(uuid=entity_id)).one_or_none() + if not dto: + raise NotFound(entity_id) + return Entity(dto) + + def save(self, entity: Entity) -> None: + self._session.add(entity.dto) + self._session.flush() diff --git a/examples/usrsrv/component/service.py b/examples/usrsrv/component/service.py new file mode 100644 index 0000000..9846736 --- /dev/null +++ b/examples/usrsrv/component/service.py @@ -0,0 +1,68 @@ +from abc import ABC +from abc import abstractmethod +from functools import singledispatch +from typing import Callable +from typing import List +from typing import Optional + +from .command import Command +from .command import Create +from .command import UpdateValue +from .entity import Entity +from .entity import EntityID +from .event import app_event +from .event import Created +from .event import Event + + +Listener = Callable[[Event], None] + + +class Repository(ABC): + @abstractmethod + def get(self, entity_id: EntityID) 
-> Entity: + raise NotImplementedError + + @abstractmethod + def save(self, entity: Entity) -> None: + raise NotImplementedError + + +class CommandHandler: + def __init__(self, repository: Repository) -> None: + self._repository = repository + self._listeners: List[Listener] = [] + super().__init__() + + def register(self, listener: Listener) -> None: + if listener not in self._listeners: + self._listeners.append(listener) + + def unregister(self, listener: Listener) -> None: + if listener in self._listeners: + self._listeners.remove(listener) + + @singledispatch + def handle(self, command: Command) -> Optional[Event]: + entity: Entity = self._repository.get(command.entity_id) + + event: Event = app_event(self._handle(command, entity), command) + for listener in self._listeners: + listener(event) + + self._repository.save(entity) + return event + + @handle.register(Create) + def create(self, command: Create) -> Event: + entity = Entity.create() + self._repository.save(entity) + return Created(command.command_id, entity.id) + + @singledispatch + def _handle(self, c: Command, u: Entity) -> Entity.Event: + raise NotImplementedError + + @_handle.register(UpdateValue) + def _(self, command: UpdateValue, entity: Entity) -> Entity.Event: + return entity.update(command.value) diff --git a/pyproject.toml b/pyproject.toml index 8e2348c..17e8c4b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -14,7 +14,8 @@ dependencies = [ "pydantic", "tenacity", "sqlapagination", - "exceptiongroup" + "exceptiongroup", + "python-ulid" ] requires-python = ">=3.7" readme = "README.rst" @@ -30,6 +31,9 @@ build-backend = "setuptools.build_meta" [project.optional-dependencies] +sim = [ + "quart-schema", "hashids", "web3", "dependency-injector", +] tests = [ "pytest", # "pytest-asyncio~=0.20.3", @@ -118,7 +122,7 @@ ignore_missing_imports = true [tool.pylint.messages_control] max-line-length = 100 -disable = ["missing-docstring", "protected-access"] +disable = ["invalid-name", 
"missing-docstring", "protected-access"] [tool.flakeheaven] baseline = ".flakeheaven_baseline" diff --git a/setup.cfg b/setup.cfg index ee9c148..0c8f6eb 100644 --- a/setup.cfg +++ b/setup.cfg @@ -27,6 +27,7 @@ ignore = WPS463 allowed-domain-names = + db value val vals diff --git a/src/quart_sqlalchemy/__init__.py b/src/quart_sqlalchemy/__init__.py index b85cd50..28059e4 100644 --- a/src/quart_sqlalchemy/__init__.py +++ b/src/quart_sqlalchemy/__init__.py @@ -1,5 +1,5 @@ __version__ = "3.0.2" - +from . import util from .bind import AsyncBind from .bind import Bind from .bind import BindContext diff --git a/src/quart_sqlalchemy/bind.py b/src/quart_sqlalchemy/bind.py index 8a69eaa..dce08c2 100644 --- a/src/quart_sqlalchemy/bind.py +++ b/src/quart_sqlalchemy/bind.py @@ -1,8 +1,12 @@ from __future__ import annotations import os +import threading import typing as t +from contextlib import asynccontextmanager from contextlib import contextmanager +from contextlib import ExitStack +from weakref import WeakValueDictionary import sqlalchemy import sqlalchemy.event @@ -11,6 +15,7 @@ import sqlalchemy.ext.asyncio import sqlalchemy.orm import sqlalchemy.util +import typing_extensions as tx from . import signals from .config import BindConfig @@ -21,8 +26,16 @@ sa = sqlalchemy +SqlAMode = tx.Literal["orm", "core"] + + +class BindNotInitialized(RuntimeError): + """ "Bind not initialized yet.""" + class BindBase: + name: t.Optional[str] + url: sa.URL config: BindConfig metadata: sa.MetaData engine: sa.Engine @@ -30,66 +43,157 @@ class BindBase: def __init__( self, - config: BindConfig, - metadata: sa.MetaData, + name: t.Optional[str] = None, + url: t.Union[sa.URL, str] = "sqlite://", + config: t.Optional[BindConfig] = None, + metadata: t.Optional[sa.MetaData] = None, ): - self.config = config - self.metadata = metadata - - @property - def url(self) -> str: - if not hasattr(self, "engine"): - raise RuntimeError("Database not initialized yet. 
Call initialize() first.") - return str(self.engine.url) + self.name = name + self.url = sa.make_url(url) + self.config = config or BindConfig.default() + self.metadata = metadata or sa.MetaData() @property def is_async(self) -> bool: - if not hasattr(self, "engine"): - raise RuntimeError("Database not initialized yet. Call initialize() first.") - return self.engine.url.get_dialect().is_async + return self.url.get_dialect().is_async @property - def is_read_only(self): + def is_read_only(self) -> bool: return self.config.read_only + def __repr__(self) -> str: + parts = [type(self).__name__] + if self.name: + parts.append(self.name) + if self.url: + parts.append(str(self.url)) + if self.is_read_only: + parts.append("[read-only]") + + return f"<{' '.join(parts)}>" + class BindContext(BindBase): pass class Bind(BindBase): + lock: threading.Lock + _instances: WeakValueDictionary = WeakValueDictionary() + def __init__( self, - config: BindConfig, - metadata: sa.MetaData, + name: t.Optional[str] = None, + url: t.Union[sa.URL, str] = "sqlite://", + config: t.Optional[BindConfig] = None, + metadata: t.Optional[sa.MetaData] = None, initialize: bool = True, + track_instance: bool = False, ): - self.config = config - self.metadata = metadata + super().__init__(name, url, config, metadata) + self._initialization_lock = threading.Lock() + + if track_instance: + self._track_instance(name) if initialize: self.initialize() - def initialize(self): - if hasattr(self, "engine"): - self.engine.dispose() + self._session_stack = [] - self.engine = self.create_engine( - self.config.engine.dict(exclude_unset=True, exclude_none=True), - prefix="", - ) - self.Session = self.create_session_factory( - self.config.session.dict(exclude_unset=True, exclude_none=True), - ) + def initialize(self) -> tx.Self: + with self._initialization_lock: + if hasattr(self, "engine"): + self.engine.dispose() + + engine_config = self.config.engine.dict(exclude_unset=True, exclude_none=True) + 
engine_config.setdefault("url", self.url) + self.engine = self.create_engine(engine_config, prefix="") + + session_options = self.config.session.dict(exclude_unset=True, exclude_none=True) + self.Session = self.create_session_factory(session_options) return self + def _track_instance(self, name): + if name is None: + return + + if name in Bind._instances: + raise ValueError("Bind instance `{name}` already exists, use another name.") + else: + Bind._instances[name] = self + + @classmethod + def get_instance(cls, name: str = "default") -> Bind: + """Get the singleton instance having `name`. + + This enables some really cool patterns similar to how logging allows getting an already + initialized logger from anywhere without importing it directly. Features like this are + most useful when working in web frameworks like flask and quart that are more prone to + circular dependency issues. + + Example: + app/db.py: + from quart_sqlalchemy import Bind + + default = Bind("default", url="sqlite://") + + with default.Session() as session: + with session.begin(): + session.add(User()) + + + app/views/v1/user/login.py + from quart_sqlalchemy import Bind + + # get the same `default` bind already instantiated in app/db.py + default = Bind.get_instance("default") + + with default.Session() as session: + with session.begin(): + session.add(User()) + ... + """ + try: + return Bind._instances[name]() + except KeyError as err: + raise ValueError(f"Bind instance `{name}` does not exist.") from err + + @t.overload + @contextmanager + def transaction(self, mode: SqlAMode = "orm") -> t.Generator[sa.orm.Session, None, None]: + ... + + @t.overload + @contextmanager + def transaction(self, mode: SqlAMode = "core") -> t.Generator[sa.Connection, None, None]: + ... 
+ + @contextmanager + def transaction( + self, mode: SqlAMode = "orm" + ) -> t.Generator[t.Union[sa.orm.Session, sa.Connection], None, None]: + if mode == "orm": + with self.Session() as session: + with session.begin(): + yield session + elif mode == "core": + with self.engine.connect() as connection: + with connection.begin(): + yield connection + else: + raise ValueError(f"Invalid transaction mode `{mode}`") + + def test_transaction(self, savepoint: bool = False) -> TestTransaction: + return TestTransaction(self, savepoint=savepoint) + @contextmanager def context( self, engine_execution_options: t.Optional[t.Dict[str, t.Any]] = None, session_execution__options: t.Optional[t.Dict[str, t.Any]] = None, ) -> t.Generator[BindContext, None, None]: - context = BindContext(self.config, self.metadata) + context = BindContext(f"{self.name}-context", self.url, self.config, self.metadata) context.engine = self.engine.execution_options(**engine_execution_options or {}) context.Session = self.create_session_factory(session_execution__options or {}) context.Session.configure(bind=context.engine) @@ -110,7 +214,7 @@ def context( ) def create_session_factory( - self, options: dict[str, t.Any] + self, options: t.Dict[str, t.Any] ) -> sa.orm.sessionmaker[sa.orm.Session]: signals.before_bind_session_factory_created.send(self, options=options) session_factory = sa.orm.sessionmaker(bind=self.engine, **options) @@ -125,9 +229,6 @@ def create_engine(self, config: t.Dict[str, t.Any], prefix: str = "") -> sa.Engi signals.after_bind_engine_created.send(self, config=config, prefix=prefix, engine=engine) return engine - def test_transaction(self, savepoint: bool = False): - return TestTransaction(self, savepoint=savepoint) - def _call_metadata(self, method: str): with self.engine.connect() as conn: with conn.begin(): @@ -142,14 +243,27 @@ def drop_all(self): def reflect(self): return self._call_metadata("reflect") - def __repr__(self) -> str: - return f"<{type(self).__name__} 
{self.engine.url}>" - class AsyncBind(Bind): engine: sa.ext.asyncio.AsyncEngine Session: sa.ext.asyncio.async_sessionmaker + @asynccontextmanager + async def transaction(self, mode: SqlAMode = "orm"): + if mode == "orm": + async with self.Session() as session: + async with session.begin(): + yield session + elif mode == "core": + async with self.engine.connect() as connection: + async with connection.begin(): + yield connection + else: + raise ValueError(f"Invalid transaction mode `{mode}`") + + def test_transaction(self, savepoint: bool = False): + return AsyncTestTransaction(self, savepoint=savepoint) + def create_session_factory( self, options: dict[str, t.Any] ) -> sa.ext.asyncio.async_sessionmaker[sa.ext.asyncio.AsyncSession]: @@ -182,9 +296,6 @@ def create_engine( signals.after_bind_engine_created.send(self, config=config, prefix=prefix, engine=engine) return engine - def test_transaction(self, savepoint: bool = False): - return AsyncTestTransaction(self, savepoint=savepoint) - async def _call_metadata(self, method: str): async with self.engine.connect() as conn: async with conn.begin(): diff --git a/src/quart_sqlalchemy/config.py b/src/quart_sqlalchemy/config.py index 5caa0ec..0190160 100644 --- a/src/quart_sqlalchemy/config.py +++ b/src/quart_sqlalchemy/config.py @@ -1,6 +1,5 @@ from __future__ import annotations -import json import os import types import typing as t @@ -11,6 +10,7 @@ import sqlalchemy.ext import sqlalchemy.ext.asyncio import sqlalchemy.orm +import sqlalchemy.sql.sqltypes import sqlalchemy.util import typing_extensions as tx from pydantic import BaseModel @@ -22,6 +22,8 @@ from .model import Base from .types import BoundParamStyle from .types import DMLStrategy +from .types import Empty +from .types import EmptyType from .types import SessionBind from .types import SessionBindKey from .types import SynchronizeSession @@ -63,9 +65,9 @@ class ConfigBase(BaseModel): class Config: arbitrary_types_allowed = True - @classmethod - def 
default(cls): - return cls() + @root_validator + def scrub_empty(cls, values): + return {key: val for key, val in values.items() if val not in [Empty, {}]} class CoreExecutionOptions(ConfigBase): @@ -73,15 +75,15 @@ class CoreExecutionOptions(ConfigBase): https://docs.sqlalchemy.org/en/20/core/connections.html#sqlalchemy.engine.Connection.execution_options """ - isolation_level: t.Optional[TransactionIsolationLevel] = None - compiled_cache: t.Optional[t.Dict[t.Any, Compiled]] = Field(default_factory=dict) - logging_token: t.Optional[str] = None - no_parameters: bool = False - stream_results: bool = False - max_row_buffer: int = 1000 - yield_per: t.Optional[int] = None - insertmanyvalues_page_size: int = 1000 - schema_translate_map: t.Optional[t.Dict[str, str]] = None + isolation_level: t.Union[TransactionIsolationLevel, EmptyType] = Empty + compiled_cache: t.Union[t.Dict[t.Any, Compiled], None, EmptyType] = Empty + logging_token: t.Union[str, None, EmptyType] = Empty + no_parameters: t.Union[bool, EmptyType] = Empty + stream_results: t.Union[bool, EmptyType] = Empty + max_row_buffer: t.Union[int, EmptyType] = Empty + yield_per: t.Union[int, None, EmptyType] = Empty + insertmanyvalues_page_size: t.Union[int, EmptyType] = Empty + schema_translate_map: t.Union[t.Dict[str, str], None, EmptyType] = Empty class ORMExecutionOptions(ConfigBase): @@ -89,14 +91,21 @@ class ORMExecutionOptions(ConfigBase): https://docs.sqlalchemy.org/en/20/orm/queryguide/api.html#orm-queryguide-execution-options """ - isolation_level: t.Optional[TransactionIsolationLevel] = None - stream_results: bool = False - yield_per: t.Optional[int] = None - populate_existing: bool = False - autoflush: bool = True - identity_token: t.Optional[str] = None - synchronize_session: SynchronizeSession = "auto" - dml_strategy: DMLStrategy = "auto" + isolation_level: t.Union[TransactionIsolationLevel, EmptyType] = Empty + stream_results: t.Union[bool, EmptyType] = Empty + yield_per: t.Union[int, None, EmptyType] 
= Empty + populate_existing: t.Union[bool, EmptyType] = Empty + autoflush: t.Union[bool, EmptyType] = Empty + identity_token: t.Union[str, None, EmptyType] = Empty + synchronize_session: t.Union[SynchronizeSession, None, EmptyType] = Empty + dml_strategy: t.Union[DMLStrategy, None, EmptyType] = Empty + + +# connect_args: +# mysql: +# connect_timeout: +# postgres: +# connect_timeout: class EngineConfig(ConfigBase): @@ -104,42 +113,55 @@ class EngineConfig(ConfigBase): https://docs.sqlalchemy.org/en/20/core/engines.html#sqlalchemy.create_engine """ - url: t.Union[sa.URL, str] = "sqlite://" - echo: bool = False - echo_pool: bool = False - connect_args: t.Dict[str, t.Any] = Field(default_factory=dict) + url: t.Union[sa.URL, str, EmptyType] = Empty + echo: t.Union[bool, EmptyType] = Empty + echo_pool: t.Union[bool, EmptyType] = Empty + connect_args: t.Union[t.Dict[str, t.Any], EmptyType] = Empty execution_options: CoreExecutionOptions = Field(default_factory=CoreExecutionOptions) - enable_from_linting: bool = True - hide_parameters: bool = False - insertmanyvalues_page_size: int = 1000 - isolation_level: t.Optional[TransactionIsolationLevel] = None - json_deserializer: t.Callable[[str], t.Any] = json.loads - json_serializer: t.Callable[[t.Any], str] = json.dumps - label_length: t.Optional[int] = None - logging_name: t.Optional[str] = None - max_identifier_length: t.Optional[int] = None - max_overflow: int = 10 - module: t.Optional[types.ModuleType] = None - paramstyle: t.Optional[BoundParamStyle] = None - pool: t.Optional[sa.Pool] = None - poolclass: t.Optional[t.Type[sa.Pool]] = None - pool_logging_name: t.Optional[str] = None - pool_pre_ping: bool = False - pool_size: int = 5 - pool_recycle: int = -1 - pool_reset_on_return: t.Optional[tx.Literal["values", "rollback"]] = None - pool_timeout: int = 40 - pool_use_lifo: bool = False - plugins: t.Sequence[str] = Field(default_factory=list) - query_cache_size: int = 500 - user_insertmanyvalues: bool = True + 
enable_from_linting: t.Union[bool, EmptyType] = Empty + hide_parameters: t.Union[bool, EmptyType] = Empty + insertmanyvalues_page_size: t.Union[int, EmptyType] = Empty + isolation_level: t.Union[TransactionIsolationLevel, EmptyType] = Empty + json_deserializer: t.Union[t.Callable[[str], t.Any], EmptyType] = Empty + json_serializer: t.Union[t.Callable[[t.Any], str], EmptyType] = Empty + label_length: t.Union[int, None, EmptyType] = Empty + logging_name: t.Union[str, None, EmptyType] = Empty + max_identifier_length: t.Union[int, None, EmptyType] = Empty + max_overflow: t.Union[int, EmptyType] = Empty + module: t.Union[types.ModuleType, None, EmptyType] = Empty + paramstyle: t.Union[BoundParamStyle, None, EmptyType] = Empty + pool: t.Union[sa.Pool, None, EmptyType] = Empty + poolclass: t.Union[t.Type[sa.Pool], None, EmptyType] = Empty + pool_logging_name: t.Union[str, None, EmptyType] = Empty + pool_pre_ping: t.Union[bool, EmptyType] = Empty + pool_size: t.Union[int, EmptyType] = Empty + pool_recycle: t.Union[int, EmptyType] = Empty + pool_reset_on_return: t.Union[tx.Literal["values", "rollback"], None, EmptyType] = Empty + pool_timeout: t.Union[int, EmptyType] = Empty + pool_use_lifo: t.Union[bool, EmptyType] = Empty + plugins: t.Union[t.Sequence[str], EmptyType] = Empty + query_cache_size: t.Union[int, EmptyType] = Empty + user_insertmanyvalues: t.Union[bool, EmptyType] = Empty - @classmethod - def default(cls): - return cls(url="sqlite://") + @root_validator + def scrub_execution_options(cls, values): + if "execution_options" in values: + execute_options = values["execution_options"].dict(exclude_defaults=True) + if execute_options: + values["execution_options"] = execute_options + return values + + @root_validator + def set_defaults(cls, values): + values.setdefault("url", "sqlite://") + return values @root_validator def apply_driver_defaults(cls, values): + # values["execution_options"] = values["execution_options"].dict(exclude_defaults=True) + # values = {key: 
val for key, val in values.items() if val not in [Empty, {}]} + # values.setdefault("url", "sqlite://") + url = sa.make_url(values["url"]) driver = url.drivername @@ -177,19 +199,31 @@ def apply_driver_defaults(cls, values): return values +class AsyncEngineConfig(EngineConfig): + @root_validator + def set_defaults(cls, values): + values.setdefault("url", "sqlite+aiosqlite://") + return values + + class SessionOptions(ConfigBase): """ https://docs.sqlalchemy.org/en/20/orm/session_api.html#sqlalchemy.orm.Session """ - autoflush: bool = True - autobegin: bool = True - expire_on_commit: bool = False - bind: t.Optional[SessionBind] = None - binds: t.Optional[t.Dict[SessionBindKey, SessionBind]] = None - twophase: bool = False - info: t.Optional[t.Dict[t.Any, t.Any]] = None - join_transaction_mode: JoinTransactionMode = "conditional_savepoint" + autoflush: t.Union[bool, EmptyType] = Empty + autobegin: t.Union[bool, EmptyType] = Empty + expire_on_commit: t.Union[bool, EmptyType] = Empty + bind: t.Union[SessionBind, None, EmptyType] = Empty + binds: t.Union[t.Dict[SessionBindKey, SessionBind], None, EmptyType] = Empty + twophase: t.Union[bool, EmptyType] = Empty + info: t.Union[t.Dict[t.Any, t.Any], None, EmptyType] = Empty + join_transaction_mode: t.Union[JoinTransactionMode, EmptyType] = Empty + + @root_validator + def set_defaults(cls, values): + values.setdefault("expire_on_commit", False) + return values class SessionmakerOptions(SessionOptions): @@ -197,7 +231,7 @@ class SessionmakerOptions(SessionOptions): https://docs.sqlalchemy.org/en/20/orm/session_api.html#sqlalchemy.orm.sessionmaker """ - class_: t.Type[sa.orm.Session] = sa.orm.Session + class_: t.Union[t.Type[sa.orm.Session], EmptyType] = Empty class AsyncSessionOptions(SessionOptions): @@ -205,7 +239,7 @@ class AsyncSessionOptions(SessionOptions): https://docs.sqlalchemy.org/en/20/orm/extensions/asyncio.html#sqlalchemy.ext.asyncio.AsyncSession """ - sync_session_class: t.Type[sa.orm.Session] = sa.orm.Session 
+ sync_session_class: t.Union[t.Type[sa.orm.Session], EmptyType] = Empty class AsyncSessionmakerOptions(AsyncSessionOptions): @@ -213,13 +247,14 @@ class AsyncSessionmakerOptions(AsyncSessionOptions): https://docs.sqlalchemy.org/en/20/orm/extensions/asyncio.html#sqlalchemy.ext.asyncio.async_sessionmaker """ - class_: t.Type[sa.ext.asyncio.AsyncSession] = sa.ext.asyncio.AsyncSession + class_: t.Union[t.Type[sa.ext.asyncio.AsyncSession], EmptyType] = Empty class BindConfig(ConfigBase): read_only: bool = False - session: SessionmakerOptions = Field(default_factory=SessionmakerOptions.default) - engine: EngineConfig = Field(default_factory=EngineConfig.default) + session: SessionmakerOptions = Field(default_factory=SessionmakerOptions) + engine: EngineConfig = Field(default_factory=EngineConfig) + track_instance: bool = False @root_validator def validate_dialect(cls, values): @@ -227,30 +262,24 @@ def validate_dialect(cls, values): class AsyncBindConfig(BindConfig): - session: AsyncSessionmakerOptions = Field(default_factory=AsyncSessionmakerOptions.default) + session: AsyncSessionmakerOptions = Field(default_factory=AsyncSessionmakerOptions) + engine: AsyncEngineConfig = Field(default_factory=AsyncEngineConfig) @root_validator def validate_dialect(cls, values): return validate_dialect(cls, values, "async") -def default(): - dict(default=dict()) - - class SQLAlchemyConfig(ConfigBase): - class Meta: - web_config_field_map = { - "SQLALCHEMY_MODEL_CLASS": "model_class", - "SQLALCHEMY_BINDS": "binds", - } + base_class: t.Type[t.Any] = Base + binds: t.Dict[str, t.Union[BindConfig, AsyncBindConfig]] = Field(default_factory=dict) - model_class: t.Type[t.Any] = Base - binds: t.Dict[str, t.Union[AsyncBindConfig, BindConfig]] = Field( - default_factory=lambda: dict(default=BindConfig()) - ) + @root_validator + def set_default_bind(cls, values): + values.setdefault("binds", dict(default=BindConfig())) + return values @classmethod - def from_framework(cls, values: t.Dict[str, 
t.Any]): - key_map = cls.Meta.web_config_field_map - return cls(**{key_map.get(key, key): val for key, val in values.items()}) + def from_framework(cls, framework_config): + config = framework_config.get_namespace("SQLALCHEMY_") + return cls.parse_obj(config or {}) diff --git a/src/quart_sqlalchemy/framework/cli.py b/src/quart_sqlalchemy/framework/cli.py index 13ac0af..6f48455 100644 --- a/src/quart_sqlalchemy/framework/cli.py +++ b/src/quart_sqlalchemy/framework/cli.py @@ -1,28 +1,83 @@ import json import sys +import typing as t import urllib.parse import click -from quart import current_app from quart.cli import AppGroup +from quart.cli import pass_script_info +from quart.cli import ScriptInfo + +from quart_sqlalchemy import signals + + +if t.TYPE_CHECKING: + from quart_sqlalchemy.framework import QuartSQLAlchemy db_cli = AppGroup("db") +fixtures_cli = AppGroup("fixtures") -@db_cli.command("info", with_appcontext=True) +@db_cli.command("info") +@pass_script_info @click.option("--uri-only", is_flag=True, default=False, help="Only output the connection uri") -def db_info(uri_only=False): - db = current_app.extensions["sqlalchemy"].db - uri = urllib.parse.unquote(str(db.engine.url)) - db_info = dict(db.engine.url._asdict()) +def db_info(info: ScriptInfo, uri_only=False): + app = info.load_app() + db: "QuartSQLAlchemy" = app.extensions["sqlalchemy"] + uri = urllib.parse.unquote(str(db.bind.url)) + info = dict(db.bind.url._asdict()) if uri_only: click.echo(uri) sys.exit(0) click.echo("Database Connection Info") - click.echo(json.dumps(db_info, indent=2)) + click.echo(json.dumps(info, indent=2)) click.echo("\n") click.echo("Connection URI") click.echo(uri) + + +@db_cli.command("create") +@pass_script_info +def create(info: ScriptInfo) -> None: + app = info.load_app() + db: "QuartSQLAlchemy" = app.extensions["sqlalchemy"] + db.create_all() + + click.echo(f"Initialized database schema for {db}") + + +@db_cli.command("drop") +@pass_script_info +def drop(info: ScriptInfo) 
-> None: + app = info.load_app() + db: "QuartSQLAlchemy" = app.extensions["sqlalchemy"] + db.drop_all() + + click.echo(f"Dropped database schema for {db}") + + +@db_cli.command("recreate") +@pass_script_info +def recreate(info: ScriptInfo) -> None: + app = info.load_app() + db: "QuartSQLAlchemy" = app.extensions["sqlalchemy"] + db.drop_all() + db.create_all() + + click.echo(f"Recreated database schema for {db}") + + +@fixtures_cli.command("load") +@pass_script_info +def load(info: ScriptInfo) -> None: + app = info.load_app() + db: "QuartSQLAlchemy" = app.extensions["sqlalchemy"] + signals.framework_extension_load_fixtures.send(sender=db, app=app) + + click.echo(f"Loaded database fixtures for {db}") + + +db_cli.add_command(fixtures_cli) diff --git a/src/quart_sqlalchemy/framework/extension.py b/src/quart_sqlalchemy/framework/extension.py index 4d29d74..70f0737 100644 --- a/src/quart_sqlalchemy/framework/extension.py +++ b/src/quart_sqlalchemy/framework/extension.py @@ -11,10 +11,12 @@ class QuartSQLAlchemy(SQLAlchemy): def __init__( self, - config: SQLAlchemyConfig, + config: t.Optional[SQLAlchemyConfig] = None, app: t.Optional[Quart] = None, ): - super().__init__(config) + initialize = False if config is None else True + super().__init__(config, initialize=initialize) + if app is not None: self.init_app(app) @@ -24,15 +26,21 @@ def init_app(self, app: Quart) -> None: f"A {type(self).__name__} instance has already been registered on this app" ) + if self.config is None: + self.config = SQLAlchemyConfig.from_framework(app.config) + self.initialize() + signals.before_framework_extension_initialization.send(self, app=app) app.extensions["sqlalchemy"] = self @app.shell_context_processor def export_sqlalchemy_objects(): + nonlocal self + return dict( db=self, - **{m.class_.__name__: m.class_ for m in self.Model._sa_registry.mappers}, + **{m.class_.__name__: m.class_ for m in self.Base.registry.mappers}, ) app.cli.add_command(db_cli) diff --git 
a/src/quart_sqlalchemy/model/__init__.py b/src/quart_sqlalchemy/model/__init__.py index cf7a4ea..fce0fbf 100644 --- a/src/quart_sqlalchemy/model/__init__.py +++ b/src/quart_sqlalchemy/model/__init__.py @@ -1,7 +1,9 @@ -from .columns import CreatedTimestamp +from .columns import Created +from .columns import IntPK from .columns import Json -from .columns import PrimaryKey -from .columns import UpdatedTimestamp +from .columns import ULID +from .columns import Updated +from .columns import UUID from .custom_types import PydanticType from .custom_types import TZDateTime from .mixins import DynamicArgsMixin @@ -15,3 +17,6 @@ from .mixins import TimestampMixin from .mixins import VersionMixin from .model import Base +from .model import BaseMixins +from .model import default_metadata_naming_convention +from .model import default_type_annotation_map diff --git a/src/quart_sqlalchemy/model/columns.py b/src/quart_sqlalchemy/model/columns.py index c00d9eb..d1bfd94 100644 --- a/src/quart_sqlalchemy/model/columns.py +++ b/src/quart_sqlalchemy/model/columns.py @@ -2,6 +2,8 @@ import typing as t from datetime import datetime +from uuid import UUID +from uuid import uuid4 import sqlalchemy import sqlalchemy.event @@ -12,21 +14,24 @@ import sqlalchemy.util import sqlalchemy_utils import typing_extensions as tx +from ulid import ULID sa = sqlalchemy sau = sqlalchemy_utils +IntPK = tx.Annotated[int, sa.orm.mapped_column(primary_key=True, autoincrement=True)] +UUID = tx.Annotated[UUID, sa.orm.mapped_column(default=uuid4)] +ULID = tx.Annotated[ULID, sa.orm.mapped_column(default=ULID)] -PrimaryKey = tx.Annotated[int, sa.orm.mapped_column(sa.Identity(), primary_key=True)] -CreatedTimestamp = tx.Annotated[ +Created = tx.Annotated[ datetime, sa.orm.mapped_column( default=sa.func.now(), server_default=sa.FetchedValue(), ), ] -UpdatedTimestamp = tx.Annotated[ +Updated = tx.Annotated[ datetime, sa.orm.mapped_column( default=sa.func.now(), @@ -35,7 +40,8 @@ server_onupdate=sa.FetchedValue(), 
), ] + Json = tx.Annotated[ t.Dict[t.Any, t.Any], - sa.orm.mapped_column(sau.JSONType, default_factory=dict), + sa.orm.mapped_column(sau.JSONType, default=dict), ] diff --git a/src/quart_sqlalchemy/model/mixins.py b/src/quart_sqlalchemy/model/mixins.py index 55c0c3f..d08f446 100644 --- a/src/quart_sqlalchemy/model/mixins.py +++ b/src/quart_sqlalchemy/model/mixins.py @@ -10,6 +10,7 @@ import sqlalchemy.ext.asyncio import sqlalchemy.orm import sqlalchemy.util +import typing_extensions as tx from sqlalchemy.orm import Mapped from ..util import camel_to_snake_case @@ -18,14 +19,37 @@ sa = sqlalchemy +class ORMModel(tx.Protocol): + __table__: sa.Table + + +class SerializingModel(ORMModel): + __table__: sa.Table + + def to_dict( + self: ORMModel, + obj: t.Optional[t.Any] = None, + max_depth: int = 3, + _children_seen: t.Optional[set] = None, + _relations_seen: t.Optional[set] = None, + ) -> t.Dict[str, t.Any]: + ... + + class TableNameMixin: + __abstract__ = True + __table__: sa.Table + @sa.orm.declared_attr.directive - def __tablename__(cls) -> str: + def __tablename__(cls: t.Type[ORMModel]) -> str: return camel_to_snake_case(cls.__name__) class ReprMixin: - def __repr__(self) -> str: + __abstract__ = True + __table__: sa.Table + + def __repr__(self: ORMModel) -> str: state = sa.inspect(self) if state is None: return super().__repr__() @@ -41,7 +65,10 @@ def __repr__(self) -> str: class ComparableMixin: - def __eq__(self, other): + __abstract__ = True + __table__: sa.Table + + def __eq__(self: ORMModel, other: ORMModel) -> bool: if type(self).__name__ != type(other).__name__: return False @@ -55,37 +82,38 @@ def __eq__(self, other): class TotalOrderMixin: - def __lt__(self, other): - if type(self).__name__ != type(other).__name__: - return False + __abstract__ = True + __table__: sa.Table - for key, column in sa.inspect(type(self)).columns.items(): - if column.primary_key: - continue + def __lt__(self: ORMModel, other: ORMModel) -> bool: + if type(self).__name__ != 
type(other).__name__: + raise NotImplemented - if not (getattr(self, key) == getattr(other, key)): - return False - return True + primary_keys = sa.inspect(type(self)).primary_key + self_keys = [getattr(self, col.name) for col in primary_keys] + other_keys = [getattr(other, col.name) for col in primary_keys] + return self_keys < other_keys class SimpleDictMixin: __abstract__ = True __table__: sa.Table - def to_dict(self): + def to_dict(self) -> t.Dict[str, t.Any]: return {c.name: getattr(self, c.name) for c in self.__table__.columns} class RecursiveDictMixin: __abstract__ = True + __table__: sa.Table - def model_to_dict( - self, + def to_dict( + self: tx.Self, obj: t.Optional[t.Any] = None, - max_depth: int = 3, + max_depth: int = 1, _children_seen: t.Optional[set] = None, _relations_seen: t.Optional[set] = None, - ): + ) -> t.Dict[str, t.Any]: """Convert model to python dict, with recursion. Args: @@ -106,11 +134,7 @@ def model_to_dict( mapper = sa.inspect(obj).mapper columns = [column.key for column in mapper.columns] - get_key_value = ( - lambda c: (c, getattr(obj, c).isoformat()) - if isinstance(getattr(obj, c), datetime) - else (c, getattr(obj, c)) - ) + get_key_value = lambda c: (c, getattr(obj, c)) data = dict(map(get_key_value, columns)) if max_depth > 0: @@ -125,10 +149,12 @@ def model_to_dict( if relationship_children is not None: if relation.uselist: children = [] - for child in (c for c in relationship_children if c not in _children_seen): - _children_seen.add(child) + for child in ( + c for c in relationship_children if repr(c) not in _children_seen + ): + _children_seen.add(repr(child)) children.append( - self.model_to_dict( + self.to_dict( child, max_depth=max_depth - 1, _children_seen=_children_seen, @@ -137,7 +163,7 @@ def model_to_dict( ) data[name] = children else: - data[name] = self.model_to_dict( + data[name] = self.to_dict( relationship_children, max_depth=max_depth - 1, _children_seen=_children_seen, @@ -148,6 +174,9 @@ def model_to_dict( 
class IdentityMixin: + __abstract__ = True + __table__: sa.Table + id: Mapped[int] = sa.orm.mapped_column(sa.Identity(), primary_key=True, autoincrement=True) @@ -191,21 +220,29 @@ class User(db.Model, SoftDeleteMixin): """ __abstract__ = True + __table__: sa.Table is_active: Mapped[bool] = sa.orm.mapped_column(default=True) class TimestampMixin: __abstract__ = True + __table__: sa.Table - created_at: Mapped[datetime] = sa.orm.mapped_column(default=sa.func.now()) + created_at: Mapped[datetime] = sa.orm.mapped_column( + default=sa.func.now(), server_default=sa.FetchedValue() + ) updated_at: Mapped[datetime] = sa.orm.mapped_column( - default=sa.func.now(), onupdate=sa.func.now() + default=sa.func.now(), + onupdate=sa.func.now(), + server_default=sa.FetchedValue(), + server_onupdate=sa.FetchedValue(), ) class VersionMixin: __abstract__ = True + __table__: sa.Table version_id: Mapped[int] = sa.orm.mapped_column(nullable=False) @@ -222,6 +259,7 @@ class EagerDefaultsMixin: """ __abstract__ = True + __table__: sa.Table @sa.orm.declared_attr.directive def __mapper_args__(cls) -> dict[str, t.Any]: @@ -289,6 +327,7 @@ def accumulate_tuples_with_mapping(class_, attribute) -> t.Sequence[t.Any]: class DynamicArgsMixin: __abstract__ = True + __table__: sa.Table @sa.orm.declared_attr.directive def __mapper_args__(cls) -> t.Dict[str, t.Any]: diff --git a/src/quart_sqlalchemy/model/model.py b/src/quart_sqlalchemy/model/model.py index cd82457..313b6b3 100644 --- a/src/quart_sqlalchemy/model/model.py +++ b/src/quart_sqlalchemy/model/model.py @@ -1,6 +1,7 @@ from __future__ import annotations import enum +import uuid import sqlalchemy import sqlalchemy.event @@ -10,20 +11,49 @@ import sqlalchemy.orm import sqlalchemy.util import typing_extensions as tx +from sqlalchemy_utils import JSONType from .mixins import ComparableMixin from .mixins import DynamicArgsMixin +from .mixins import EagerDefaultsMixin +from .mixins import RecursiveDictMixin from .mixins import ReprMixin from .mixins 
import TableNameMixin +from .mixins import TotalOrderMixin sa = sqlalchemy - -class Base(DynamicArgsMixin, ReprMixin, ComparableMixin, TableNameMixin): +default_metadata_naming_convention = { + "ix": "ix_%(column_0_label)s", # INDEX + "uq": "uq_%(table_name)s_%(column_0_N_name)s", # UNIQUE + "ck": "ck_%(table_name)s_%(constraint_name)s", # CHECK + "fk": "fk_%(table_name)s_%(column_0_N_name)s_%(referred_table_name)s", # FOREIGN KEY + "pk": "pk_%(table_name)s", # PRIMARY KEY +} + +default_type_annotation_map = { + enum.Enum: sa.Enum(enum.Enum, native_enum=False, validate_strings=True), + tx.Literal: sa.Enum(enum.Enum, native_enum=False, validate_strings=True), + uuid.UUID: sa.Uuid, + dict: JSONType, +} + + +class BaseMixins( + DynamicArgsMixin, + EagerDefaultsMixin, + ReprMixin, + RecursiveDictMixin, + TotalOrderMixin, + ComparableMixin, + TableNameMixin, +): __abstract__ = True + __table__: sa.Table + - type_annotation_map = { - enum.Enum: sa.Enum(enum.Enum, native_enum=False, validate_strings=True), - tx.Literal: sa.Enum(enum.Enum, native_enum=False, validate_strings=True), - } +class Base(BaseMixins, sa.orm.DeclarativeBase): + __abstract__ = True + metadata = sa.MetaData(naming_convention=default_metadata_naming_convention) + type_annotation_map = default_type_annotation_map diff --git a/src/quart_sqlalchemy/retry.py b/src/quart_sqlalchemy/retry.py index 664dc66..8353dc3 100644 --- a/src/quart_sqlalchemy/retry.py +++ b/src/quart_sqlalchemy/retry.py @@ -99,6 +99,7 @@ async def add_user_post(db, user_id, post_values): import sqlalchemy.exc import sqlalchemy.orm import tenacity +from tenacity import RetryError sa = sqlalchemy diff --git a/src/quart_sqlalchemy/session.py b/src/quart_sqlalchemy/session.py index 4833a5c..36db6c3 100644 --- a/src/quart_sqlalchemy/session.py +++ b/src/quart_sqlalchemy/session.py @@ -1,6 +1,9 @@ from __future__ import annotations import typing as t +from contextlib import contextmanager +from contextvars import ContextVar +from functools 
import wraps import sqlalchemy import sqlalchemy.exc @@ -18,6 +21,56 @@ sa = sqlalchemy +""" +Requirements: + * a global context var session + * a context manager that sets the session value and manages its lifetime + * a factory that will always return the current session value + * a decorator that will inject the current session value +""" + +_global_contextual_session = ContextVar("_global_contextual_session") + + +@contextmanager +def set_global_contextual_session(session, bind=None): + token = _global_contextual_session.set(session) + try: + yield + finally: + _global_contextual_session.reset(token) + + +def provide_global_contextual_session(func): + @wraps(func) + def wrapper(self, *args, **kwargs): + session_in_args = any( + [isinstance(arg, (sa.orm.Session, sa.ext.asyncio.AsyncSession)) for arg in args] + ) + session_in_kwargs = "session" in kwargs + session_provided = session_in_args or session_in_kwargs + + if session_provided: + return func(self, *args, **kwargs) + else: + session = session_proxy() + + return func(self, session, *args, **kwargs) + + return wrapper + + +class SessionProxy: + def __call__(self) -> t.Union[sa.orm.Session, sa.ext.asyncio.AsyncSession]: + return _global_contextual_session.get() + + def __getattr__(self, name): + return getattr(self(), name) + + +session_proxy = SessionProxy() + + class Session(sa.orm.Session, t.Generic[EntityT, EntityIdT]): """A SQLAlchemy :class:`~sqlalchemy.orm.Session` class. diff --git a/src/quart_sqlalchemy/signals.py b/src/quart_sqlalchemy/signals.py index 3015b16..ecef762 100644 --- a/src/quart_sqlalchemy/signals.py +++ b/src/quart_sqlalchemy/signals.py @@ -113,3 +113,34 @@ def handle(sender: QuartSQLAlchemy, app: Quart): ... """, ) + + +framework_extension_load_fixtures = sync_signals.signal( + "quart-sqlalchemy.framework.extension.fixtures.load", + doc="""Fired to load fixtures into a fresh database. + + No default signal handlers exist for this signal as the logic is very application dependent. 
+ This signal handler is typically triggered using the CLI: + + $ quart db fixtures load + + Example: + + @signals.framework_extension_load_fixtures.connect + def handle(sender: QuartSQLAlchemy, app: Quart): + db = sender.get_bind("default") + with db.Session() as session: + with session.begin(): + session.add_all( + [ + models.User(username="user1"), + models.User(username="user2"), + ] + ) + session.commit() + + Handler signature: + def handle(sender: QuartSQLAlchemy, app: Quart): + ... + """, +) diff --git a/src/quart_sqlalchemy/sim/__init__.py b/src/quart_sqlalchemy/sim/__init__.py new file mode 100644 index 0000000..6a3f395 --- /dev/null +++ b/src/quart_sqlalchemy/sim/__init__.py @@ -0,0 +1,2 @@ +from . import app +from . import model diff --git a/src/quart_sqlalchemy/sim/app.py b/src/quart_sqlalchemy/sim/app.py new file mode 100644 index 0000000..c138f42 --- /dev/null +++ b/src/quart_sqlalchemy/sim/app.py @@ -0,0 +1,41 @@ +import logging +import typing as t +from copy import deepcopy + +from quart import Quart +from quart_schema import QuartSchema +from werkzeug.utils import import_string + +from .config import settings +from .container import Container + + +logging.basicConfig(level=logging.INFO) +logger = logging.getLogger(__name__) + + +schema = QuartSchema(security_schemes=settings.SECURITY_SCHEMES) + + +def create_app(override_config: t.Optional[t.Dict[str, t.Any]] = None): + override_config = override_config or {} + + config = deepcopy(settings.dict()) + config.update(override_config) + + app = Quart(__name__) + app.config.from_mapping(config) + app.config.from_prefixed_env() + + for path in app.config["LOAD_EXTENSIONS"]: + extension = import_string(path) + extension.init_app(app) + + for path in app.config["LOAD_BLUEPRINTS"]: + bp = import_string(path) + app.register_blueprint(bp) + + container = Container(app=app) + app.container = container + + return app diff --git a/src/quart_sqlalchemy/sim/auth.py b/src/quart_sqlalchemy/sim/auth.py new file mode 
100644 index 0000000..dece2ba --- /dev/null +++ b/src/quart_sqlalchemy/sim/auth.py @@ -0,0 +1,313 @@ +import logging +import re +import secrets +import typing as t + +import click +import sqlalchemy +import sqlalchemy.orm +import sqlalchemy.orm.exc +from quart import current_app +from quart import g +from quart import Quart +from quart import request +from quart import Request +from quart.cli import AppGroup +from quart.cli import pass_script_info +from quart.cli import ScriptInfo +from quart_schema.extension import QUART_SCHEMA_SECURITY_ATTRIBUTE +from quart_schema.extension import security_scheme +from quart_schema.openapi import APIKeySecurityScheme +from quart_schema.openapi import HttpSecurityScheme +from quart_schema.openapi import SecuritySchemeBase +from sqlalchemy.orm import Session +from werkzeug.exceptions import Forbidden + +from .model import AuthUser +from .model import EntityType +from .model import MagicClient +from .model import Provenance +from .schema import BaseSchema +from .util import ObjectID + + +sa = sqlalchemy + +cli = AppGroup("auth") + +logging.basicConfig(level=logging.INFO) +logger = logging.getLogger(__name__) + + +def authorized_request(security_schemes: t.Sequence[t.Dict[str, t.List[t.Any]]]): + def decorator(func): + return security_scheme(security_schemes)(func) + + return decorator + + +class MyRequest(Request): + @property + def ip_addr(self): + return self.remote_addr + + @property + def locale(self): + return self.accept_languages.best_match(["en"]) or "en" + + @property + def redirect_url(self): + return self.args.get("redirect_url") or self.headers.get("x-redirect-url") + + +class ValidatorError(RuntimeError): + pass + + +class SubjectNotFound(ValidatorError): + pass + + +class CredentialNotFound(ValidatorError): + pass + + +class Credential(BaseSchema): + scheme: SecuritySchemeBase + value: t.Optional[str] = None + subject: t.Union[MagicClient, AuthUser] + + +class AuthenticationValidator: + name: str + scheme: 
SecuritySchemeBase + + def extract(self, request: Request) -> str: + ... + + def lookup(self, value: str, session: Session) -> t.Any: + ... + + def authenticate(self, request: Request) -> Credential: + ... + + +class PublicAPIKeyValidator(AuthenticationValidator): + name = "public-api-key" + scheme = APIKeySecurityScheme(in_="header", name="X-Public-API-Key") + + def extract(self, request: Request) -> str: + if self.scheme.in_ == "header": + return request.headers.get(self.scheme.name, None) + elif self.scheme.in_ == "cookie": + return request.cookies.get(self.scheme.name, None) + elif self.scheme.in_ == "query": + return request.args.get(self.scheme.name, None) + else: + raise ValueError(f"No token found for {self.scheme}") + + def lookup(self, value: str, session: Session) -> t.Any: + statement = sa.select(MagicClient).where(MagicClient.public_api_key == value).limit(1) + + try: + result = session.scalars(statement).one() + except sa.orm.exc.NoResultFound: + raise SubjectNotFound(f"No MagicClient found for public_api_key {value}") + + return result + + def authenticate(self, request: Request, session: Session) -> Credential: + value = self.extract(request) + if value is None: + raise CredentialNotFound() + subject = self.lookup(value, session) + return Credential(scheme=self.scheme, value=value, subject=subject) + + +class SessionTokenValidator(AuthenticationValidator): + name = "session-token-bearer" + scheme = HttpSecurityScheme(scheme="bearer", bearer_format="opaque") + + AUTHORIZATION_PATTERN = re.compile(r"Bearer (?P.+)") + + def extract(self, request: Request) -> str: + if self.scheme.scheme != "bearer": + return + + value = request.headers.get("authorization") + m = self.AUTHORIZATION_PATTERN.match(value) + if m is None: + raise ValueError("Bearer token failed validation") + + return m.group("token") + + def lookup(self, value: str, session: Session) -> t.Any: + statement = sa.select(AuthUser).where(AuthUser.current_session_token == value).limit(1) + + 
try: + result = session.scalars(statement).one() + except sa.orm.exc.NoResultFound: + raise SubjectNotFound(f"No AuthUser found for session_token {value}") + + return result + + def authenticate(self, request: Request, session: Session) -> Credential: + value = self.extract(request) + if value is None: + raise CredentialNotFound() + subject = self.lookup(value, session) + return Credential(scheme=self.scheme, value=value, subject=subject) + + +class RequestAuthenticator: + validators = [PublicAPIKeyValidator(), SessionTokenValidator()] + validator_scheme_map = {v.name: v for v in validators} + + def enforce(self, security_schemes: t.Sequence[t.Dict[str, t.List[t.Any]]], session: Session): + passed, failed = [], [] + for scheme_credential in self.validate_security(security_schemes, session): + if all(scheme_credential.values()): + passed.append(scheme_credential) + else: + failed.append(scheme_credential) + if passed: + return passed + raise Forbidden() + + def validate_security( + self, security_schemes: t.Sequence[t.Dict[str, t.List[t.Any]]], session: Session + ): + if not security_schemes: + return + + for scheme in security_schemes: + scheme_credentials = {} + for name, _ in scheme.items(): + validator = self.validator_scheme_map[name] + credential = None + try: + credential = validator.authenticate(request, session) + except ValidatorError: + pass + except: + logger.exception(f"Unknown error while validating {name}") + raise + finally: + scheme_credentials[name] = credential + yield scheme_credentials + + +class QuartAuth: + authenticator = RequestAuthenticator() + + def __init__(self, app: t.Optional[Quart] = None, bind_name: str = "default"): + self.bind_name = bind_name + + if app is not None: + self.init_app(app) + + def init_app(self, app: Quart): + app.before_request(self.auth_endpoint_security) + + app.request_class = MyRequest + + self.security_schemes = app.config.get("QUART_AUTH_SECURITY_SCHEMES", {}) + app.cli.add_command(cli) + + 
app.extensions["auth"] = self + + def auth_endpoint_security(self): + db = current_app.extensions.get("sqlalchemy") + view_function = current_app.view_functions[request.endpoint] + security_schemes = getattr(view_function, QUART_SCHEMA_SECURITY_ATTRIBUTE, None) + if security_schemes is None: + g.authorized_credentials = {} + + bind = db.get_bind(self.bind_name) + with bind.Session() as session: + results = self.authenticator.enforce(security_schemes, session) + authorized_credentials = {} + for result in results: + authorized_credentials.update(result) + g.authorized_credentials = authorized_credentials + + +class RequestCredentials: + def __init__(self, request): + self.request = request + + @property + def current_user(self): + return g.authorized_credentials.get("session-token-bearer") + + @property + def current_client(self): + return g.authorized_credentials.get("public-api-key") + + +@cli.command("add-user") +@click.option( + "--email", + type=str, + default="default@none.com", + help="email", +) +@click.option( + "--user-type", + # type=click.Choice(list(EntityType.__members__)), + type=click.Choice(["FORTMATIC", "MAGIC", "CONNECT"]), + default="MAGIC", + help="user type", +) +@click.option( + "--client-id", + type=str, + required=True, + help="client id", +) +@pass_script_info +def add_user(info: ScriptInfo, email: str, user_type: str, client_id: str) -> None: + app = info.load_app() + db = app.extensions.get("sqlalchemy") + auth = app.extensions.get("auth") + bind = db.get_bind(auth.bind_name) + with bind.Session() as s: + with s.begin(): + user = AuthUser( + email=email, + user_type=EntityType[user_type].value, + client_id=ObjectID(client_id), + provenance=Provenance.LINK, + current_session_token=secrets.token_hex(16), + ) + s.add(user) + s.flush() + s.refresh(user) + + click.echo(f"Created user {user.id} with session_token: {user.current_session_token}") + + +@cli.command("add-client") +@click.option( + "--name", + type=str, + default="My App", + 
help="app name", +) +@pass_script_info +def add_client(info: ScriptInfo, name: str) -> None: + app = info.load_app() + db = app.extensions.get("sqlalchemy") + auth = app.extensions.get("auth") + bind = db.get_bind(auth.bind_name) + with bind.Session() as s: + with s.begin(): + client = MagicClient(app_name=name, public_api_key=secrets.token_hex(16)) + s.add(client) + s.flush() + s.refresh(client) + + click.echo(f"Created client {client.id} with public_api_key: {client.public_api_key}") + + +auth = QuartAuth() diff --git a/src/quart_sqlalchemy/sim/builder.py b/src/quart_sqlalchemy/sim/builder.py new file mode 100644 index 0000000..5aa96d6 --- /dev/null +++ b/src/quart_sqlalchemy/sim/builder.py @@ -0,0 +1,96 @@ +from __future__ import annotations + +import typing as t + +import sqlalchemy +import sqlalchemy.event +import sqlalchemy.exc +import sqlalchemy.orm +import sqlalchemy.sql +from sqlalchemy.orm.interfaces import ORMOption + +from quart_sqlalchemy.types import ColumnExpr +from quart_sqlalchemy.types import DMLTable +from quart_sqlalchemy.types import EntityT +from quart_sqlalchemy.types import Selectable + + +sa = sqlalchemy + + +class StatementBuilder(t.Generic[EntityT]): + model: t.Optional[t.Type[EntityT]] + + def __init__(self, model: t.Optional[t.Type[EntityT]] = None): + self.model = model + + def select( + self, + selectables: t.Sequence[Selectable] = (), + conditions: t.Sequence[ColumnExpr] = (), + group_by: t.Sequence[t.Union[ColumnExpr, str]] = (), + order_by: t.Sequence[t.Union[ColumnExpr, str]] = (), + options: t.Sequence[ORMOption] = (), + execution_options: t.Optional[t.Dict[str, t.Any]] = None, + offset: t.Optional[int] = None, + limit: t.Optional[int] = None, + distinct: bool = False, + for_update: bool = False, + ) -> sa.Select: + statement = sa.select(*selectables or self.model).where(*conditions) + + if for_update: + statement = statement.with_for_update() + if offset: + statement = statement.offset(offset) + if limit: + statement = 
statement.limit(limit) + if group_by: + statement = statement.group_by(*group_by) + if order_by: + statement = statement.order_by(*order_by) + + for option in options: + for context in option.context: + for strategy in context.strategy: + if "joined" in strategy: + distinct = True + + statement = statement.options(option) + + if distinct: + statement = statement.distinct() + + if execution_options: + statement = statement.execution_options(**execution_options) + + return statement + + def insert( + self, + target: t.Optional[DMLTable] = None, + values: t.Optional[t.Dict[str, t.Any]] = None, + ) -> sa.Insert: + return sa.insert(target or self.model).values(**values or {}) + + def bulk_insert( + self, + target: t.Optional[DMLTable] = None, + values: t.Sequence[t.Dict[str, t.Any]] = (), + ) -> sa.Insert: + return sa.insert(target or self.model).values(*values) + + def bulk_update( + self, + target: t.Optional[DMLTable] = None, + conditions: t.Sequence[ColumnExpr] = (), + values: t.Optional[t.Dict[str, t.Any]] = None, + ) -> sa.Update: + return sa.update(target or self.model).where(*conditions).values(**values or {}) + + def bulk_delete( + self, + target: t.Optional[DMLTable] = None, + conditions: t.Sequence[ColumnExpr] = (), + ) -> sa.Delete: + return sa.delete(target or self.model).where(*conditions) diff --git a/src/quart_sqlalchemy/sim/commands.py b/src/quart_sqlalchemy/sim/commands.py new file mode 100644 index 0000000..c2a6567 --- /dev/null +++ b/src/quart_sqlalchemy/sim/commands.py @@ -0,0 +1,52 @@ +import asyncio +import sys + +import click +import IPython +from IPython.terminal.ipapp import load_default_config +from quart import current_app + + +def attach(app): + app.shell_context_processor(app_env) + app.cli.command( + with_appcontext=True, + context_settings=dict( + ignore_unknown_options=True, + ), + )(ishell) + + +def app_env(): + app = current_app + return dict(container=app.container) + + +@click.argument("ipython_args", nargs=-1, 
type=click.UNPROCESSED) +def ishell(ipython_args): + import nest_asyncio + + nest_asyncio.apply() + + config = load_default_config() + + asyncio.run(current_app.startup()) + + context = current_app.make_shell_context() + + config.TerminalInteractiveShell.banner1 = """Python %s on %s +IPython: %s +App: %s [%s] +""" % ( + sys.version, + sys.platform, + IPython.__version__, + current_app.import_name, + current_app.env, + ) + + IPython.start_ipython( + argv=ipython_args, + user_ns=context, + config=config, + ) diff --git a/src/quart_sqlalchemy/sim/config.py b/src/quart_sqlalchemy/sim/config.py new file mode 100644 index 0000000..d743ee7 --- /dev/null +++ b/src/quart_sqlalchemy/sim/config.py @@ -0,0 +1,55 @@ +import typing as t + +import sqlalchemy +from pydantic import BaseSettings +from pydantic import Field +from pydantic import PyObject +from quart_schema import APIKeySecurityScheme +from quart_schema import HttpSecurityScheme +from quart_schema.openapi import SecuritySchemeBase + +from quart_sqlalchemy import AsyncBindConfig +from quart_sqlalchemy import BindConfig +from quart_sqlalchemy.sim.db import MyBase + + +sa = sqlalchemy + + +class AppSettings(BaseSettings): + class Config: + env_file = ".env", ".secrets.env" + + LOAD_BLUEPRINTS: t.List[str] = Field( + default_factory=lambda: list(("quart_sqlalchemy.sim.views.api",)) + ) + LOAD_EXTENSIONS: t.List[str] = Field( + default_factory=lambda: list( + ( + "quart_sqlalchemy.sim.db.db", + "quart_sqlalchemy.sim.app.schema", + "quart_sqlalchemy.sim.auth.auth", + ) + ) + ) + SECURITY_SCHEMES: t.Dict[str, SecuritySchemeBase] = Field( + default_factory=lambda: { + "public-api-key": APIKeySecurityScheme(in_="header", name="X-Public-API-Key"), + "session-token-bearer": HttpSecurityScheme(scheme="bearer", bearer_format="opaque"), + } + ) + + SQLALCHEMY_BINDS: t.Dict[str, t.Union[AsyncBindConfig, BindConfig]] = Field( + default_factory=lambda: dict(default=BindConfig(engine=dict(url="sqlite:///app.db"))) + ) + 
SQLALCHEMY_BASE_CLASS: t.Type[t.Any] = Field(default=MyBase) + + WEB3_DEFAULT_CHAIN: str = Field(default="ethereum") + WEB3_DEFAULT_NETWORK: str = Field(default="goerli") + + WEB3_PROVIDER_CLASS: PyObject = Field("web3.providers.HTTPProvider", env="WEB3_PROVIDER_CLASS") + ALCHEMY_API_KEY: str = Field(env="ALCHEMY_API_KEY") + WEB3_HTTPS_PROVIDER_URI: str = Field(env="WEB3_HTTPS_PROVIDER_URI") + + +settings = AppSettings() diff --git a/src/quart_sqlalchemy/sim/container.py b/src/quart_sqlalchemy/sim/container.py new file mode 100644 index 0000000..91bb358 --- /dev/null +++ b/src/quart_sqlalchemy/sim/container.py @@ -0,0 +1,57 @@ +import typing as t + +import sqlalchemy.orm +from dependency_injector import containers +from dependency_injector import providers +from quart import request + +from quart_sqlalchemy.session import SessionProxy +from quart_sqlalchemy.sim.auth import RequestCredentials +from quart_sqlalchemy.sim.handle import AuthUserHandler +from quart_sqlalchemy.sim.handle import AuthWalletHandler +from quart_sqlalchemy.sim.handle import MagicClientHandler +from quart_sqlalchemy.sim.logic import LogicComponent + +from .config import AppSettings +from .web3 import Web3 +from .web3 import web3_node_factory + + +sa = sqlalchemy + + +def get_db_from_app(app): + return app.extensions["sqlalchemy"] + + +class Container(containers.DeclarativeContainer): + wiring_config = containers.WiringConfiguration( + modules=[ + "quart_sqlalchemy.sim.views", + "quart_sqlalchemy.sim.logic", + "quart_sqlalchemy.sim.handle", + "quart_sqlalchemy.sim.views.auth_wallet", + "quart_sqlalchemy.sim.views.auth_user", + "quart_sqlalchemy.sim.views.magic_client", + ] + ) + config = providers.Configuration(pydantic_settings=[AppSettings()]) + app = providers.Object() + db = providers.Singleton(get_db_from_app, app=app) + + session_factory = providers.Singleton(SessionProxy) + logic = providers.Singleton(LogicComponent) + + AuthUserHandler = providers.Singleton(AuthUserHandler) + 
MagicClientHandler = providers.Singleton(MagicClientHandler) + AuthWalletHandler = providers.Singleton(AuthWalletHandler) + + web3_node = providers.Singleton(web3_node_factory, config=config) + web3 = providers.Singleton( + Web3, + node=web3_node, + default_network=config.WEB3_DEFAULT_NETWORK, + default_chain=config.WEB3_DEFAULT_CHAIN, + ) + current_request = providers.Factory(lambda: request) + request_credentials = providers.Singleton(RequestCredentials, request=current_request) diff --git a/src/quart_sqlalchemy/sim/db.py b/src/quart_sqlalchemy/sim/db.py new file mode 100644 index 0000000..9634815 --- /dev/null +++ b/src/quart_sqlalchemy/sim/db.py @@ -0,0 +1,106 @@ +import click +import sqlalchemy +import sqlalchemy.orm +from quart.cli import AppGroup +from quart.cli import pass_script_info +from quart.cli import ScriptInfo +from sqlalchemy.types import Integer +from sqlalchemy.types import TypeDecorator + +from quart_sqlalchemy import SQLAlchemyConfig +from quart_sqlalchemy.framework import QuartSQLAlchemy +from quart_sqlalchemy.model import BaseMixins +from quart_sqlalchemy.sim.util import ObjectID + + +sa = sqlalchemy +cli = AppGroup("db-schema") + + +def init_fixtures(session): + """Initialize the database with some fixtures.""" + from quart_sqlalchemy.sim.model import AuthUser + from quart_sqlalchemy.sim.model import MagicClient + + client = MagicClient( + app_name="My App", + public_api_key="4700aed5ee9f76f7be6398cd4b00b586", + auth_users=[ + AuthUser( + email="joe@magic.link", + current_session_token="97ee741d53e11a490460927c8a2ce4a3", + ), + ], + ) + session.add(client) + session.flush() + + +class ObjectIDType(TypeDecorator): + """A custom database column type that converts integer value to our ObjectID. + This allows us to pass around ObjectID type in the application for easy + frontend encoding and database decoding on the integer value. + + Note: all id db column type should use this type for its column. 
+ """ + + impl = Integer + cache_ok = False + + def process_bind_param(self, value, dialect): + """Data going into to the database will be transformed by this method. + See ``ObjectID`` for the design and rational for this. + """ + if value is None: + return None + + return ObjectID(value).decode() + + def process_result_value(self, value, dialect): + """Data going out from the database will be explicitly casted to the + ``ObjectID``. + """ + if value is None: + return None + + return ObjectID(value) + + +class MyBase(BaseMixins, sa.orm.DeclarativeBase): + __abstract__ = True + type_annotation_map = {ObjectID: ObjectIDType} + + +@cli.command("load") +@pass_script_info +def schema_load(info: ScriptInfo) -> None: + app = info.load_app() + db = app.extensions.get("sqlalchemy") + db.create_all() + + click.echo(f"Initialized database schema for {db}") + + +# sqlite:///file:mem.db?mode=memory&cache=shared&uri=true +db = QuartSQLAlchemy( + SQLAlchemyConfig.parse_obj( + { + "base_class": MyBase, + "binds": { + "default": { + "engine": {"url": "sqlite:///file:sim.db?cache=shared&uri=true"}, + "session": {"expire_on_commit": False}, + }, + "read-replica": { + "engine": {"url": "sqlite:///file:sim.db?cache=shared&uri=true"}, + "session": {"expire_on_commit": False}, + "read_only": True, + }, + "async": { + "engine": {"url": "sqlite+aiosqlite:///file:sim.db?cache=shared&uri=true"}, + "session": {"expire_on_commit": False}, + }, + }, + } + ) +) diff --git a/src/quart_sqlalchemy/sim/handle.py b/src/quart_sqlalchemy/sim/handle.py new file mode 100644 index 0000000..bd3bc12 --- /dev/null +++ b/src/quart_sqlalchemy/sim/handle.py @@ -0,0 +1,359 @@ +from __future__ import annotations + +import logging +import secrets +import typing as t +from datetime import datetime + +import sqlalchemy +import sqlalchemy.orm +from dependency_injector.wiring import Provide +from quart import Quart + +from quart_sqlalchemy.session import provide_global_contextual_session +from quart_sqlalchemy.sim 
import signals +from quart_sqlalchemy.sim.logic import LogicComponent +from quart_sqlalchemy.sim.model import AuthUser +from quart_sqlalchemy.sim.model import AuthWallet +from quart_sqlalchemy.sim.model import EntityType +from quart_sqlalchemy.sim.model import WalletType +from quart_sqlalchemy.sim.util import ObjectID + + +sa = sqlalchemy + +logger = logging.getLogger(__name__) + +CLIENTS_PER_API_USER_LIMIT = 50 + + +def get_product_type_by_client_id(_): + return EntityType.MAGIC.value + + +class MaxClientsExceeded(Exception): + pass + + +class AuthUserBaseError(Exception): + pass + + +class InvalidSubstringError(AuthUserBaseError): + pass + + +class HandlerBase: + logic: LogicComponent = Provide["logic"] + + +class MagicClientHandler(HandlerBase): + auth_user_handler: AuthUserHandler = Provide["AuthUserHandler"] + + @provide_global_contextual_session + def add( + self, + session: sa.orm.Session, + app_name=None, + rate_limit_tier=None, + connect_interop=None, + is_signing_modal_enabled=False, + global_audience_enabled=False, + ): + """Registers a new client. + + Args: + is_magic_connect_enabled (boolean): if True, it will create a Magic Connect app. + + Returns: + A ``MagicClient``. 
+ """ + + return self.logic.MagicClient.add( + session, + app_name=app_name, + rate_limit_tier=rate_limit_tier, + connect_interop=connect_interop, + is_signing_modal_enabled=is_signing_modal_enabled, + global_audience_enabled=global_audience_enabled, + ) + + @provide_global_contextual_session + def get_by_public_api_key(self, session: sa.orm.Session, public_api_key): + return self.logic.MagicClient.get_by_public_api_key(session, public_api_key) + + @provide_global_contextual_session + def get_by_id(self, session: sa.orm.Session, magic_client_id): + return self.logic.MagicClient.get_by_id(session, magic_client_id) + + @provide_global_contextual_session + def update_app_name_by_id(self, session: sa.orm.Session, magic_client_id, app_name): + """ + Args: + magic_client_id (ObjectID|int|str): self explanatory. + app_name (str): Desired application name. + + Returns: + None if `magic_client_id` doesn't exist in the db + app_name if update was successful + """ + client = self.logic.MagicClient.update_by_id(session, magic_client_id, app_name=app_name) + + if not client: + return None + + return client.app_name + + @provide_global_contextual_session + def update_by_id(self, session: sa.orm.Session, magic_client_id, **kwargs): + client = self.logic.MagicClient.update_by_id(session, magic_client_id, **kwargs) + + return client + + @provide_global_contextual_session + def set_inactive_by_id(self, session: sa.orm.Session, magic_client_id): + """ + Args: + magic_client_id (ObjectID|int|str): self explanatory. 
+ + Returns: + None + """ + self.logic.MagicClient.update_by_id(session, magic_client_id, is_active=False) + + @provide_global_contextual_session + def get_users_for_client( + self, + session: sa.orm.Session, + magic_client_id, + offset=None, + limit=None, + ): + """ + Returns emails and signup timestamps for all auth users belonging to a given client + """ + product_type = get_product_type_by_client_id(magic_client_id) + auth_users = self.auth_user_handler.get_by_client_id_and_user_type( + session, + magic_client_id, + product_type, + offset=offset, + limit=limit, + ) + + return { + "users": [ + dict(email=u.email or "none", signup_ts=int(datetime.timestamp(u.time_created))) + for u in auth_users + ] + } + + +class AuthUserHandler(HandlerBase): + @provide_global_contextual_session + def get_by_session_token(self, session: sa.orm.Session, session_token): + return self.logic.AuthUser.get_by_session_token(session, session_token) + + @provide_global_contextual_session + def get_or_create_by_email_and_client_id( + self, + session: sa.orm.Session, + email, + client_id, + user_type=EntityType.MAGIC.value, + ): + with session.begin_nested(): + auth_user = self.logic.AuthUser.get_by_email_and_client_id( + session, + email, + client_id, + user_type=user_type, + for_update=True, + ) + if not auth_user: + auth_user = self.logic.AuthUser.add_by_email_and_client_id( + session, + client_id, + email=email, + user_type=user_type, + ) + return auth_user + + @provide_global_contextual_session + def create_verified_user( + self, + session: sa.orm.Session, + client_id, + email, + user_type=EntityType.FORTMATIC.value, + **kwargs, + ): + with session.begin_nested(): + auid = self.logic.AuthUser.add_by_email_and_client_id( + session, + client_id, + email, + user_type=user_type, + **kwargs, + ).id + + session.flush() + + auth_user = self.logic.AuthUser.update_by_id( + session, + auid, + date_verified=datetime.utcnow(), + current_session_token=secrets.token_hex(16), + ) + + return 
auth_user + + @provide_global_contextual_session + def get_by_id(self, session: sa.orm.Session, auth_user_id) -> AuthUser: + return self.logic.AuthUser.get_by_id(session, auth_user_id) + + @provide_global_contextual_session + def get_by_client_id_and_user_type( + self, + session: sa.orm.Session, + client_id, + user_type, + offset=None, + limit=None, + ): + return self.logic.AuthUser.get_by_client_id_and_user_type( + session, + client_id, + user_type, + offset=offset, + limit=limit, + ) + + @provide_global_contextual_session + def exist_by_email_client_id_and_user_type( + self, session: sa.orm.Session, email, client_id, user_type + ): + return self.logic.AuthUser.exist_by_email_and_client_id( + session, + email, + client_id, + user_type=user_type, + ) + + @provide_global_contextual_session + def update_email_by_id(self, session: sa.orm.Session, model_id, email): + return self.logic.AuthUser.update_by_id(session, model_id, email=email) + + @provide_global_contextual_session + def get_by_email_client_id_and_user_type( + self, session: sa.orm.Session, email, client_id, user_type + ): + return self.logic.AuthUser.get_by_email_and_client_id( + session, + email, + client_id, + user_type, + ) + + @provide_global_contextual_session + def mark_date_verified_by_id(self, session: sa.orm.Session, model_id): + return self.logic.AuthUser.update_by_id( + session, + model_id, + date_verified=datetime.utcnow(), + ) + + @provide_global_contextual_session + def set_role_by_email_magic_client_id( + self, session: sa.orm.Session, email, magic_client_id, role + ): + session = session + auth_user = self.logic.AuthUser.get_by_email_and_client_id( + session, + email, + magic_client_id, + EntityType.MAGIC.value, + for_update=True, + ) + + if not auth_user: + auth_user = self.logic.AuthUser.add_by_email_and_client_id( + session, + magic_client_id, + email, + user_type=EntityType.MAGIC.value, + ) + + session.flush() + + return self.logic.AuthUser.update_by_id(session, auth_user.id, **{role: 
True}) + + @provide_global_contextual_session + def mark_as_inactive(self, session: sa.orm.Session, auth_user_id): + self.logic.AuthUser.update_by_id(session, auth_user_id, is_active=False) + + +@signals.auth_user_duplicate.connect +def handle_duplicate_auth_users( + app: Quart, + original_auth_user_id: ObjectID, + duplicate_auth_user_ids: t.Sequence[ObjectID], +) -> None: + for dupe_id in duplicate_auth_user_ids: + app.container.logic().AuthUser.update_by_id(dupe_id, is_active=False) + + +class AuthWalletHandler(HandlerBase): + @provide_global_contextual_session + def get_by_id(self, session: sa.orm.Session, model_id): + return self.logic.AuthWallet.get_by_id(session, model_id) + + @provide_global_contextual_session + def get_by_public_address(self, session: sa.orm.Session, public_address): + return self.logic.AuthWallet().get_by_public_address(session, public_address) + + @provide_global_contextual_session + def get_by_auth_user_id( + self, + session: sa.orm.Session, + auth_user_id: ObjectID, + network: t.Optional[str] = None, + wallet_type: t.Optional[WalletType] = None, + **kwargs, + ) -> t.List[AuthWallet]: + return self.logic.AuthWallet.get_by_auth_user_id( + session, + auth_user_id, + network=network, + wallet_type=wallet_type, + **kwargs, + ) + + @provide_global_contextual_session + def sync_auth_wallet( + self, + session: sa.orm.Session, + auth_user_id, + public_address, + encrypted_private_address, + wallet_management_type, + network: t.Optional[str] = None, + wallet_type: t.Optional[WalletType] = None, + ): + with session.begin_nested(): + existing_wallet = self.logic.AuthWallet.get_by_auth_user_id( + session, + auth_user_id, + ) + if existing_wallet: + raise RuntimeError("WalletExistsForNetworkAndWalletType") + + return self.logic.AuthWallet.add( + session, + public_address, + encrypted_private_address, + wallet_type, + network, + management_type=wallet_management_type, + auth_user_id=auth_user_id, + ) diff --git a/src/quart_sqlalchemy/sim/logic.py 
b/src/quart_sqlalchemy/sim/logic.py new file mode 100644 index 0000000..626adbf --- /dev/null +++ b/src/quart_sqlalchemy/sim/logic.py @@ -0,0 +1,480 @@ +import logging +import secrets +import typing as t +from datetime import datetime + +import sqlalchemy +import sqlalchemy.orm +from quart import current_app + +from quart_sqlalchemy.session import provide_global_contextual_session +from quart_sqlalchemy.sim import signals +from quart_sqlalchemy.sim.model import AuthUser as auth_user_model +from quart_sqlalchemy.sim.model import AuthWallet as auth_wallet_model +from quart_sqlalchemy.sim.model import EntityType +from quart_sqlalchemy.sim.model import MagicClient as magic_client_model +from quart_sqlalchemy.sim.repo_adapter import RepositoryLegacyAdapter +from quart_sqlalchemy.sim.util import ObjectID +from quart_sqlalchemy.sim.util import one +from quart_sqlalchemy.types import EntityIdT +from quart_sqlalchemy.types import EntityT +from quart_sqlalchemy.types import SessionT + + +logger = logging.getLogger(__name__) +sa = sqlalchemy + + +class LogicMeta(type): + _ignore = {"LegacyLogicComponent"} + + def __init__(cls, name, bases, cls_dict): + if not hasattr(cls, "_registry"): + cls._registry = {} + else: + if cls.__name__ not in cls._ignore: + model = getattr(cls, "model", None) + if model is not None: + name = model.__name__ + + cls._registry[name] = cls() + + super().__init__(name, bases, cls_dict) + + +class LogicComponent(t.Generic[EntityT, EntityIdT, SessionT], metaclass=LogicMeta): + def __dir__(self): + return super().__dir__() + list(self._registry.keys()) + + def __getattr__(self, name): + if name in self._registry: + return self._registry[name] + else: + raise AttributeError(f"{type(self).__name__} has no attribute '{name}'") + + +class MagicClient(LogicComponent[magic_client_model, ObjectID, sa.orm.Session]): + model = magic_client_model + identity = ObjectID + _repository = RepositoryLegacyAdapter(model, identity) + + @provide_global_contextual_session + 
def add(self, session, app_name=None, **kwargs): + public_api_key = secrets.token_hex(16) + return self._repository.add( + session, + app_name=app_name, + **kwargs, + public_api_key=public_api_key, + ) + + @provide_global_contextual_session + def get_by_id( + self, + session, + model_id, + allow_inactive=False, + join_list=None, + ) -> t.Optional[magic_client_model]: + return self._repository.get_by_id( + session, + model_id, + allow_inactive=allow_inactive, + join_list=join_list, + ) + + @provide_global_contextual_session + def get_by_public_api_key( + self, + session, + public_api_key, + ): + return one( + self._repository.get_by( + session, + filters=[magic_client_model.public_api_key == public_api_key], + limit=1, + ) + ) + + @provide_global_contextual_session + def update_by_id(self, session, model_id, **update_params): + modified_row = self._repository.update(session, model_id, **update_params) + session.refresh(modified_row) + return modified_row + + @provide_global_contextual_session + def yield_all_clients_by_chunk(self, session, chunk_size): + yield from self._repository.yield_by_chunk(session, chunk_size) + + @provide_global_contextual_session + def yield_by_chunk(self, session, chunk_size, filters=None, join_list=None): + yield from self._repository.yield_by_chunk( + session, + chunk_size, + filters=filters, + join_list=join_list, + ) + + +class DuplicateAuthUser(Exception): + pass + + +class AuthUserDoesNotExist(Exception): + pass + + +class MissingEmail(Exception): + pass + + +class MissingPhoneNumber(Exception): + pass + + +class AuthUser(LogicComponent[auth_user_model, ObjectID, sa.orm.Session]): + model = auth_user_model + identity = ObjectID + _repository = RepositoryLegacyAdapter(model, identity) + + @provide_global_contextual_session + def add(self, session, **kwargs) -> auth_user_model: + return self._repository.add(session, **kwargs) + + @provide_global_contextual_session + def add_by_email_and_client_id( + self, + session, + client_id, + 
email=None, + user_type=EntityType.FORTMATIC.value, + **kwargs, + ): + if email is None: + raise MissingEmail() + + if self.exist_by_email_and_client_id( + session, + email, + client_id, + user_type=user_type, + ): + logger.exception( + "User duplication for email: {} (client_id: {})".format( + email, + client_id, + ), + ) + raise DuplicateAuthUser() + + row = self._repository.add( + session, + email=email, + client_id=client_id, + user_type=user_type, + **kwargs, + ) + logger.info( + "New auth user (id: {}) created by email (client_id: {})".format( + row.id, + client_id, + ), + ) + + return row + + @provide_global_contextual_session + def add_by_client_id( + self, + session, + client_id, + user_type=EntityType.FORTMATIC.value, + provenance=None, + global_auth_user_id=None, + is_verified=False, + ): + row = self._repository.add( + session, + client_id=client_id, + user_type=user_type, + provenance=provenance, + global_auth_user_id=global_auth_user_id, + date_verified=datetime.utcnow() if is_verified else None, + ) + logger.info( + "New auth user (id: {}) created by (client_id: {})".format(row.id, client_id), + ) + + return row + + @provide_global_contextual_session + def get_by_session_token( + self, + session, + session_token, + ): + return one( + self._repository.get_by( + session, + filters=[auth_user_model.current_session_token == session_token], + limit=1, + ) + ) + + @provide_global_contextual_session + def get_by_active_identifier_and_client_id( + self, + session, + identifier_field, + identifier_value, + client_id, + user_type, + for_update=False, + ) -> t.Optional[auth_user_model]: + """There should only be one active identifier where all the parameters match for a given client ID. 
In the case of multiple results, the subsequent entries / "dupes" will be marked as inactive.""" + filters = [ + identifier_field == identifier_value, + auth_user_model.client_id == client_id, + auth_user_model.user_type == user_type, + ] + + results = self._repository.get_by( + session, + filters=filters, + order_by_clause=auth_user_model.id.asc(), + for_update=for_update, + ) + + if not results: + return None + + original, *duplicates = results + + if duplicates: + signals.auth_user_duplicate.send( + current_app, + original_auth_user_id=original.id, + duplicate_auth_user_ids=[dupe.id for dupe in duplicates], + ) + + return original + + @provide_global_contextual_session + def get_by_email_and_client_id( + self, + session, + email, + client_id, + user_type=EntityType.FORTMATIC.value, + for_update=False, + ): + return self.get_by_active_identifier_and_client_id( + session=session, + identifier_field=auth_user_model.email, + identifier_value=email, + client_id=client_id, + user_type=user_type, + for_update=for_update, + ) + + @provide_global_contextual_session + def exist_by_email_and_client_id( + self, + session, + email, + client_id, + user_type=EntityType.FORTMATIC.value, + ): + return bool( + self._repository.exist( + session, + filters=[ + auth_user_model.email == email, + auth_user_model.client_id == client_id, + auth_user_model.user_type == user_type, + ], + ), + ) + + @provide_global_contextual_session + def get_by_id( + self, session, model_id, join_list=None, for_update=False + ) -> t.Optional[auth_user_model]: + return self._repository.get_by_id( + session, + model_id, + join_list=join_list, + for_update=for_update, + ) + + @provide_global_contextual_session + def update_by_id(self, session, auth_user_id, **kwargs): + modified_user = self._repository.update(session, auth_user_id, **kwargs) + + if modified_user is None: + raise AuthUserDoesNotExist() + + return modified_user + + @provide_global_contextual_session + def 
get_user_count_by_client_id_and_user_type(self, session, client_id, user_type): + query = ( + session.query(auth_user_model) + .filter( + auth_user_model.client_id == client_id, + auth_user_model.user_type == user_type, + auth_user_model.date_verified.is_not(None), + ) + .statement.with_only_columns(sa.func.count()) + .order_by(None) + ) + + return session.execute(query).scalar() + + @provide_global_contextual_session + def get_by_client_id_and_user_type( + self, + session, + client_id, + user_type, + offset=None, + limit=None, + ): + return self.get_by_client_ids_and_user_type( + session, + [client_id], + user_type, + offset=offset, + limit=limit, + ) + + @provide_global_contextual_session + def yield_by_chunk(self, session, chunk_size, filters=None, join_list=None): + yield from self._repository.yield_by_chunk( + session, + chunk_size, + filters=filters, + join_list=join_list, + ) + + @provide_global_contextual_session + def get_by_emails_and_client_id( + self, + session, + email_ids, + client_id, + ): + return self._repository.get_by( + session, + filters=[ + auth_user_model.email.in_(email_ids), + auth_user_model.client_id == client_id, + ], + ) + + @provide_global_contextual_session + def get_by_email( + self, + session, + email: str, + join_list=None, + filters=None, + for_update: bool = False, + ) -> t.List[auth_user_model]: + filters = filters or [] + combined_filters = filters + [auth_user_model.email == email] + + return self._repository.get_by( + session, + filters=combined_filters, + for_update=for_update, + join_list=join_list, + ) + + +class AuthWallet(LogicComponent[auth_wallet_model, ObjectID, sa.orm.Session]): + model = auth_wallet_model + identity = ObjectID + _repository = RepositoryLegacyAdapter(model, identity) + + @provide_global_contextual_session + def add( + self, + session, + public_address, + encrypted_private_address, + wallet_type, + network, + management_type=None, + auth_user_id=None, + ): + new_row = self._repository.add( + session, 
+ auth_user_id=auth_user_id, + public_address=public_address, + encrypted_private_address=encrypted_private_address, + wallet_type=wallet_type, + management_type=management_type, + network=network, + ) + + return new_row + + @provide_global_contextual_session + def get_by_id(self, session, model_id, allow_inactive=False, join_list=None): + return self._repository.get_by_id( + session, + model_id, + allow_inactive=allow_inactive, + join_list=join_list, + ) + + @provide_global_contextual_session + def get_by_public_address(self, session, public_address, network=None, is_active=True): + filters = [ + auth_wallet_model.public_address == public_address, + ] + + if network: + filters.append(auth_wallet_model.network == network) + + row = self._repository.get_by(session, filters=filters, allow_inactive=not is_active) + + if not row: + return None + + return one(row) + + @provide_global_contextual_session + def get_by_auth_user_id( + self, + session, + auth_user_id, + network=None, + wallet_type=None, + is_active=True, + join_list=None, + ): + filters = [ + auth_wallet_model.auth_user_id == auth_user_id, + ] + + if network: + filters.append(auth_wallet_model.network == network) + + if wallet_type: + filters.append(auth_wallet_model.wallet_type == wallet_type) + + rows = self._repository.get_by( + session, filters=filters, join_list=join_list, allow_inactive=not is_active + ) + + if not rows: + return [] + + return rows + + @provide_global_contextual_session + def update_by_id(self, session, model_id, **kwargs): + self._repository.update(session, model_id, **kwargs) diff --git a/src/quart_sqlalchemy/sim/main.py b/src/quart_sqlalchemy/sim/main.py new file mode 100644 index 0000000..8350c68 --- /dev/null +++ b/src/quart_sqlalchemy/sim/main.py @@ -0,0 +1,10 @@ +from quart_sqlalchemy.sim import commands +from quart_sqlalchemy.sim.app import create_app + + +app = create_app() + +commands.attach(app) + +if __name__ == "__main__": + app.run(port=8081) diff --git 
a/src/quart_sqlalchemy/sim/model.py b/src/quart_sqlalchemy/sim/model.py new file mode 100644 index 0000000..9df6fe9 --- /dev/null +++ b/src/quart_sqlalchemy/sim/model.py @@ -0,0 +1,165 @@ +import secrets +import typing as t +from datetime import datetime +from enum import Enum +from enum import IntEnum + +import sqlalchemy +import sqlalchemy.orm +from sqlalchemy.ext.hybrid import hybrid_property +from sqlalchemy.orm import Mapped + +from quart_sqlalchemy.model import SoftDeleteMixin +from quart_sqlalchemy.model import TimestampMixin +from quart_sqlalchemy.sim.db import MyBase +from quart_sqlalchemy.sim.util import ObjectID + + +sa = sqlalchemy + + +class StrEnum(str, Enum): + def __str__(self) -> str: + return str.__str__(self) + + +class ConnectInteropStatus(StrEnum): + ENABLED = "ENABLED" + DISABLED = "DISABLED" + + +class Provenance(Enum): + LINK = 1 + OAUTH = 2 + WEBAUTHN = 3 + SMS = 4 + IDENTIFIER = 5 + FEDERATED = 6 + + +class EntityType(Enum): + FORTMATIC = 1 + MAGIC = 2 + CONNECT = 3 + + +class WalletManagementType(IntEnum): + UNDELEGATED = 1 + DELEGATED = 2 + + +class WalletType(StrEnum): + ETH = "ETH" + HARMONY = "HARMONY" + ICON = "ICON" + FLOW = "FLOW" + TEZOS = "TEZOS" + ZILLIQA = "ZILLIQA" + POLKADOT = "POLKADOT" + SOLANA = "SOLANA" + AVAX = "AVAX" + ALGOD = "ALGOD" + COSMOS = "COSMOS" + CELO = "CELO" + BITCOIN = "BITCOIN" + NEAR = "NEAR" + HELIUM = "HELIUM" + CONFLUX = "CONFLUX" + TERRA = "TERRA" + TAQUITO = "TAQUITO" + ED = "ED" + HEDERA = "HEDERA" + + +class MagicClient(MyBase, SoftDeleteMixin, TimestampMixin): + __tablename__ = "magic_client" + + id: Mapped[ObjectID] = sa.orm.mapped_column(primary_key=True, autoincrement=True) + app_name: Mapped[str] = sa.orm.mapped_column(default="my new app") + rate_limit_tier: Mapped[t.Optional[str]] + connect_interop: Mapped[t.Optional[ConnectInteropStatus]] + is_signing_modal_enabled: Mapped[bool] = sa.orm.mapped_column(default=False) + global_audience_enabled: Mapped[bool] = 
sa.orm.mapped_column(default=False) + + public_api_key: Mapped[str] = sa.orm.mapped_column(default=secrets.token_hex) + secret_api_key: Mapped[str] = sa.orm.mapped_column(default=secrets.token_hex) + + auth_users: Mapped[t.List["AuthUser"]] = sa.orm.relationship( + back_populates="magic_client", + primaryjoin="and_(foreign(AuthUser.client_id) == MagicClient.id, AuthUser.user_type != 1)", + ) + + +class AuthUser(MyBase, SoftDeleteMixin, TimestampMixin): + __tablename__ = "auth_user" + + id: Mapped[ObjectID] = sa.orm.mapped_column(primary_key=True, autoincrement=True) + email: Mapped[t.Optional[str]] = sa.orm.mapped_column(index=True) + phone_number: Mapped[t.Optional[str]] = sa.orm.mapped_column(index=True) + user_type: Mapped[int] = sa.orm.mapped_column(default=EntityType.FORTMATIC.value) + date_verified: Mapped[t.Optional[datetime]] + provenance: Mapped[t.Optional[Provenance]] + is_admin: Mapped[bool] = sa.orm.mapped_column(default=False) + client_id: Mapped[ObjectID] = sa.orm.mapped_column(sa.ForeignKey("magic_client.id")) + linked_primary_auth_user_id: Mapped[t.Optional[ObjectID]] = sa.orm.mapped_column( + sa.ForeignKey("auth_user.id"), default=None + ) + global_auth_user_id: Mapped[t.Optional[ObjectID]] + + delegated_user_id: Mapped[t.Optional[str]] + delegated_identity_pool_id: Mapped[t.Optional[str]] + + current_session_token: Mapped[t.Optional[str]] + + magic_client: Mapped[MagicClient] = sa.orm.relationship( + back_populates="auth_users", + uselist=False, + ) + linked_primary_auth_user = sa.orm.relationship( + "AuthUser", + remote_side=[id], + lazy="joined", + join_depth=1, + uselist=False, + ) + wallets: Mapped[t.List["AuthWallet"]] = sa.orm.relationship(back_populates="auth_user") + + @hybrid_property + def is_email_verified(self): + return self.email is not None and self.date_verified is not None + + @hybrid_property + def is_waiting_on_email_verification(self): + return self.email is not None and self.date_verified is None + + @hybrid_property + def 
is_new_signup(self): + return self.date_verified is None + + @hybrid_property + def has_linked_primary_auth_user(self): + return bool(self.linked_primary_auth_user_id) + + @hybrid_property + def is_magic_connect_user(self): + return self.global_auth_user_id is not None and self.user_type == EntityType.CONNECT.value + + +class AuthWallet(MyBase, SoftDeleteMixin, TimestampMixin): + __tablename__ = "auth_wallet" + + id: Mapped[ObjectID] = sa.orm.mapped_column(primary_key=True, autoincrement=True) + auth_user_id: Mapped[ObjectID] = sa.orm.mapped_column(sa.ForeignKey("auth_user.id")) + wallet_type: Mapped[str] = sa.orm.mapped_column(default=WalletType.ETH.value) + management_type: Mapped[int] = sa.orm.mapped_column( + default=WalletManagementType.UNDELEGATED.value + ) + public_address: Mapped[t.Optional[str]] = sa.orm.mapped_column(index=True) + encrypted_private_address: Mapped[t.Optional[str]] + network: Mapped[str] + is_exported: Mapped[bool] = sa.orm.mapped_column(default=False) + + auth_user: Mapped[AuthUser] = sa.orm.relationship( + back_populates="wallets", + uselist=False, + ) diff --git a/src/quart_sqlalchemy/sim/repo.py b/src/quart_sqlalchemy/sim/repo.py new file mode 100644 index 0000000..618af04 --- /dev/null +++ b/src/quart_sqlalchemy/sim/repo.py @@ -0,0 +1,357 @@ +from __future__ import annotations + +import operator +import typing as t +from abc import ABCMeta + +import sqlalchemy +import sqlalchemy.event +import sqlalchemy.exc +import sqlalchemy.orm +import sqlalchemy.sql + +from quart_sqlalchemy.sim.builder import StatementBuilder +from quart_sqlalchemy.types import ColumnExpr +from quart_sqlalchemy.types import EntityIdT +from quart_sqlalchemy.types import EntityT +from quart_sqlalchemy.types import Operator +from quart_sqlalchemy.types import ORMOption +from quart_sqlalchemy.types import Selectable +from quart_sqlalchemy.types import SessionT + + +sa = sqlalchemy + + +class AbstractRepository(t.Generic[EntityT, EntityIdT, SessionT], 
metaclass=ABCMeta): + """A repository interface.""" + + model: t.Type[EntityT] + identity: t.Type[EntityIdT] + + def __init__(self, model: t.Type[EntityT], identity: t.Type[EntityIdT]): + self.model = model + self.identity = identity + + +class AbstractBulkRepository(t.Generic[EntityT, EntityIdT, SessionT], metaclass=ABCMeta): + """A repository interface for bulk operations. + + Note: this interface circumvents ORM internals, breaking commonly expected behavior in order + to gain performance benefits. Only use this class whenever absolutely necessary. + """ + + model: t.Type[EntityT] + identity: t.Type[EntityIdT] + + def __init__(self, model: t.Type[EntityT], identity: t.Type[EntityIdT]): + self.model = model + self.identity = identity + + +class SQLAlchemyRepository( + AbstractRepository[EntityT, EntityIdT, SessionT], t.Generic[EntityT, EntityIdT, SessionT] +): + """A repository that uses SQLAlchemy to persist data. + + The biggest change with this repository is that for methods returning multiple results, we + return the sa.ScalarResult so that the caller has maximum flexibility in how it's consumed. + + As a result, when calling a method such as get_by, you then need to decide how to fetch the + result. + + Methods of fetching results: + - .all() to return a list of results + - .first() to return the first result + - .one() to return the first result or raise an exception if there are no results + - .one_or_none() to return the first result or None if there are no results + - .partitions(n) to return a results as a list of n-sized sublists + + Additionally, there are methods for transforming the results prior to fetching. 
+ + Methods of transforming results: + - .unique() to apply unique filtering to the result + + """ + + builder: StatementBuilder + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.builder = StatementBuilder(self.model) + + def _build_execution_options( + self, + execution_options: t.Optional[t.Dict[str, t.Any]] = None, + include_inactive: bool = False, + yield_by_chunk: bool = False, + ): + execution_options = execution_options or {} + if include_inactive: + execution_options.setdefault("include_inactive", include_inactive) + if yield_by_chunk: + execution_options.setdefault("yield_per", yield_by_chunk) + return execution_options + + def insert(self, session: sa.orm.Session, values: t.Dict[str, t.Any]) -> EntityT: + """Insert a new model into the database.""" + new = self.model(**values) + + session.add(new) + session.flush() + session.refresh(new) + + return new + + def update( + self, session: sa.orm.Session, id_: EntityIdT, values: t.Dict[str, t.Any] + ) -> EntityT: + """Update existing model with new values.""" + obj = session.get(self.model, id_) + if obj is None: + raise ValueError(f"Object with id {id_} not found") + for field, value in values.items(): + if getattr(obj, field) != value: + setattr(obj, field, value) + + session.flush() + session.refresh(obj) + + return obj + + def merge( + self, + session: sa.orm.Session, + id_: EntityIdT, + values: t.Dict[str, t.Any], + for_update: bool = False, + ) -> EntityT: + """Merge model in session/db having id_ with values.""" + session.get(self.model, id_) + values.update(id=id_) + + merged = session.merge(self.model(**values)) + session.flush() + session.refresh(merged, with_for_update=for_update) # type: ignore + + return merged + + def get( + self, + session: sa.orm.Session, + id_: EntityIdT, + options: t.Sequence[ORMOption] = (), + execution_options: t.Optional[t.Dict[str, t.Any]] = None, + for_update: bool = False, + include_inactive: bool = False, + ) -> t.Optional[EntityT]: + 
"""Get object identified by id_ from the database. + + Note: It's a common misconception that session.get(Model, id) is akin to a shortcut for + a select(Model).where(Model.id == id) like statement. However this is not the case. + + Session.get is actually used for looking an object up in the sessions identity map. When + present it will be returned directly, when not, a database lookup will be performed. + + For use cases where this is what you actually want, you can still access the original get + method on session. For most uses cases, this behavior can introduce non-determinism + and because of that this method performs lookup using a select statement. Additionally, + to satisfy the expected interface's return type: Optional[EntityT], one_or_none is called + on the result before returning. + """ + selectables = (self.model,) + + execution_options = self._build_execution_options( + execution_options, include_inactive=include_inactive + ) + # execution_options = execution_options or {} + # if include_inactive: + # execution_options.setdefault("include_inactive", include_inactive) + + statement = self.builder.select( + selectables, # type: ignore + conditions=[self.model.id == id_], + options=options, + limit=1, + for_update=for_update, + ) + + return session.scalars(statement, execution_options=execution_options).one_or_none() + + def get_by_field( + self, + session: sa.orm.Session, + field: t.Union[ColumnExpr, str], + value: t.Any, + op: Operator = operator.eq, + order_by: t.Sequence[t.Union[ColumnExpr, str]] = (), + options: t.Sequence[ORMOption] = (), + execution_options: t.Optional[t.Dict[str, t.Any]] = None, + offset: t.Optional[int] = None, + limit: t.Optional[int] = None, + distinct: bool = False, + for_update: bool = False, + include_inactive: bool = False, + ) -> sa.ScalarResult[EntityT]: + """Select models where field is equal to value.""" + selectables = (self.model,) + + execution_options = self._build_execution_options( + execution_options, 
include_inactive=include_inactive + ) + # execution_options = execution_options or {} + # if include_inactive: + # execution_options.setdefault("include_inactive", include_inactive) + + if isinstance(field, str): + field = getattr(self.model, field) + + conditions = [t.cast(ColumnExpr, op(field, value))] + + statement = self.builder.select( + selectables, # type: ignore + conditions=conditions, + order_by=order_by, + options=options, + offset=offset, + limit=limit, + distinct=distinct, + for_update=for_update, + ) + + return session.scalars(statement, execution_options=execution_options) + + def select( + self, + session: sa.orm.Session, + selectables: t.Sequence[Selectable] = (), + conditions: t.Sequence[ColumnExpr] = (), + group_by: t.Sequence[t.Union[ColumnExpr, str]] = (), + order_by: t.Sequence[t.Union[ColumnExpr, str]] = (), + options: t.Sequence[ORMOption] = (), + execution_options: t.Optional[t.Dict[str, t.Any]] = None, + offset: t.Optional[int] = None, + limit: t.Optional[int] = None, + distinct: bool = False, + for_update: bool = False, + include_inactive: bool = False, + yield_by_chunk: t.Optional[int] = None, + ) -> t.Union[sa.ScalarResult[EntityT], t.Iterator[t.Sequence[EntityT]]]: + """Select from the database. + + Note: yield_by_chunk is not compatible with the subquery and joined loader strategies, use selectinload for eager loading. 
+ """ + selectables = selectables or (self.model,) # type: ignore + + execution_options = self._build_execution_options( + execution_options, + include_inactive=include_inactive, + yield_by_chunk=yield_by_chunk, + ) + # execution_options = execution_options or {} + # if include_inactive: + # execution_options.setdefault("include_inactive", include_inactive) + # if yield_by_chunk: + # execution_options.setdefault("yield_per", yield_by_chunk) + + statement = self.builder.select( + selectables, + conditions=conditions, + group_by=group_by, + order_by=order_by, + options=options, + execution_options=execution_options, + offset=offset, + limit=limit, + distinct=distinct, + for_update=for_update, + ) + + results = session.scalars(statement) + if yield_by_chunk: + results = results.partitions() + return results + + def delete( + self, session: sa.orm.Session, id_: EntityIdT, include_inactive: bool = False + ) -> None: + entity = self.get(session, id_, include_inactive=include_inactive) + if not entity: + raise RuntimeError(f"Entity with id {id_} not found.") + + session.delete(entity) + session.flush() + + def deactivate(self, session: sa.orm.Session, id_: EntityIdT) -> EntityT: + return self.update(session, id_, dict(is_active=False)) + + def reactivate(self, session: sa.orm.Session, id_: EntityIdT) -> EntityT: + return self.update(session, id_, dict(is_active=False)) + + def exists( + self, + session: sa.orm.Session, + conditions: t.Sequence[ColumnExpr] = (), + for_update: bool = False, + include_inactive: bool = False, + ) -> bool: + """Return whether an object matching conditions exists. + + Note: This performs better than simply trying to select an object since there is no + overhead in sending the selected object and deserializing it. 
+ """ + selectables = (sa.sql.literal(True),) + + execution_options = self._build_execution_options(None, include_inactive=include_inactive) + # execution_options = {} + # if include_inactive: + # execution_options.setdefault("include_inactive", include_inactive) + + statement = self.builder.select( + selectables, + conditions=conditions, + limit=1, + for_update=for_update, + ) + + result = session.execute(statement, execution_options=execution_options).scalar() + + return bool(result) + + +class SQLAlchemyBulkRepository( + AbstractBulkRepository[EntityT, EntityIdT, SessionT], t.Generic[EntityT, EntityIdT, SessionT] +): + builder: StatementBuilder + + def __init__(self, **kwargs): + super().__init__(**kwargs) + self.builder = StatementBuilder(None) + + def bulk_insert( + self, + session: SessionT, + values: t.Sequence[t.Dict[str, t.Any]] = (), + execution_options: t.Optional[t.Dict[str, t.Any]] = None, + ) -> sa.Result[t.Any]: + statement = self.builder.bulk_insert(self.model, values) + return session.execute(statement, execution_options=execution_options or {}) + + def bulk_update( + self, + session: SessionT, + conditions: t.Sequence[ColumnExpr] = (), + values: t.Optional[t.Dict[str, t.Any]] = None, + execution_options: t.Optional[t.Dict[str, t.Any]] = None, + ) -> sa.Result[t.Any]: + statement = self.builder.bulk_update(self.model, conditions, values) + return session.execute(statement, execution_options=execution_options or {}) + + def bulk_delete( + self, + session: SessionT, + conditions: t.Sequence[ColumnExpr] = (), + execution_options: t.Optional[t.Dict[str, t.Any]] = None, + ) -> sa.Result[t.Any]: + statement = self.builder.bulk_delete(self.model, conditions) + return session.execute(statement, execution_options=execution_options or {}) diff --git a/src/quart_sqlalchemy/sim/repo_adapter.py b/src/quart_sqlalchemy/sim/repo_adapter.py new file mode 100644 index 0000000..8e0a43d --- /dev/null +++ b/src/quart_sqlalchemy/sim/repo_adapter.py @@ -0,0 +1,333 @@ 
+import typing as t + +import sqlalchemy +import sqlalchemy.orm +from pydantic import BaseModel + +from quart_sqlalchemy.sim.repo import SQLAlchemyRepository +from quart_sqlalchemy.types import ColumnExpr +from quart_sqlalchemy.types import EntityIdT +from quart_sqlalchemy.types import EntityT +from quart_sqlalchemy.types import ORMOption +from quart_sqlalchemy.types import Selectable +from quart_sqlalchemy.types import SessionT + + +sa = sqlalchemy + + +class BaseModelSchema(BaseModel): + class Config: + from_orm = True + + +class BaseCreateSchema(BaseModelSchema): + pass + + +class BaseUpdateSchema(BaseModelSchema): + pass + + +ModelSchemaT = t.TypeVar("ModelSchemaT", bound=BaseModelSchema) +CreateSchemaT = t.TypeVar("CreateSchemaT", bound=BaseCreateSchema) +UpdateSchemaT = t.TypeVar("UpdateSchemaT", bound=BaseUpdateSchema) + + +class RepositoryLegacyAdapter(t.Generic[EntityT, EntityIdT, SessionT]): + model: t.Type[EntityT] + identity: t.Type[EntityIdT] + + def __init__(self, model: t.Type[EntityT], identity: t.Type[EntityIdT]): + self.model = model + self.identity = identity + self._adapted = SQLAlchemyRepository(model, identity) + + def get_by( + self, + session: SessionT, + filters=None, + allow_inactive=False, + join_list=None, + order_by_clause=None, + for_update=False, + offset=None, + limit=None, + ) -> t.Sequence[EntityT]: + if filters is None: + raise ValueError("Full table scans are prohibited. 
Please provide filters") + + join_list = join_list or () + + if order_by_clause is not None: + order_by_clause = (order_by_clause,) + else: + order_by_clause = () + + return self._adapted.select( + session, + conditions=filters, + options=[sa.orm.selectinload(getattr(self.model, attr)) for attr in join_list], + for_update=for_update, + order_by=order_by_clause, + offset=offset, + limit=limit, + include_inactive=allow_inactive, + ).all() + + def get_by_id( + self, + session: SessionT, + model_id=None, + allow_inactive=False, + join_list=None, + for_update=False, + ) -> t.Optional[EntityT]: + if model_id is None: + raise ValueError("model_id is required") + join_list = join_list or () + return self._adapted.get( + session, + id_=model_id, + options=[sa.orm.selectinload(getattr(self.model, attr)) for attr in join_list], + for_update=for_update, + include_inactive=allow_inactive, + ) + + def one( + self, + session: SessionT, + filters=None, + join_list=None, + for_update=False, + include_inactive=False, + ) -> EntityT: + filters = filters or () + join_list = join_list or () + return self._adapted.select( + session, + conditions=filters, + options=[sa.orm.selectinload(getattr(self.model, attr)) for attr in join_list], + for_update=for_update, + include_inactive=include_inactive, + ).one() + + def count_by( + self, + session: SessionT, + filters=None, + group_by=None, + distinct_column=None, + ): + if filters is None: + raise ValueError("Full table scans are prohibited. 
Please provide filters") + + group_by = group_by or () + + if distinct_column: + selectables = [sa.label("count", sa.func.count(sa.func.distinct(distinct_column)))] + else: + selectables = [sa.label("count", sa.func.count(self.model.id))] + + for group in group_by: + selectables.append(group.expression) + + result = self._adapted.select(session, selectables, conditions=filters, group_by=group_by) + + return result.all() + + def add(self, session: SessionT, **kwargs) -> EntityT: + return self._adapted.insert(session, kwargs) + + def update(self, session: SessionT, model_id=None, **kwargs) -> EntityT: + return self._adapted.update(session, id_=model_id, values=kwargs) + + def update_by(self, session: SessionT, filters=None, **kwargs) -> EntityT: + if not filters: + raise ValueError("Full table scans are prohibited. Please provide filters") + + row = self._adapted.select(session, conditions=filters, limit=2).one() + return self._adapted.update(session, id_=row.id, values=kwargs) + + def delete_by_id(self, session: SessionT, model_id=None) -> None: + self._adapted.delete(session, id_=model_id, include_inactive=True) + + def delete_one_by(self, session: SessionT, filters=None, optional=False) -> None: + filters = filters or () + result = self._adapted.select(session, conditions=filters, limit=1) + + if optional: + row = result.one_or_none() + if row is None: + return + else: + row = result.one() + + self._adapted.delete(session, id_=row.id) + + def exist(self, session: SessionT, filters=None, allow_inactive=False) -> bool: + filters = filters or () + return self._adapted.exists( + session, + conditions=filters, + include_inactive=allow_inactive, + ) + + def yield_by_chunk( + self, session: SessionT, chunk_size=100, join_list=None, filters=None, allow_inactive=False + ): + filters = filters or () + join_list = join_list or () + results = self._adapted.select( + session, + conditions=filters, + options=[sa.orm.selectinload(getattr(self.model, attr)) for attr in 
join_list], + include_inactive=allow_inactive, + yield_by_chunk=chunk_size, + ) + for result in results: + yield result + + +class PydanticScalarResult(sa.ScalarResult, t.Generic[ModelSchemaT]): + pydantic_schema: t.Type[ModelSchemaT] + + def __init__(self, scalar_result: t.Any, pydantic_schema: t.Type[ModelSchemaT]): + for attribute in scalar_result.__slots__: + setattr(self, attribute, getattr(scalar_result, attribute)) + self.pydantic_schema = pydantic_schema + + def _translate_many(self, rows): + return [self.pydantic_schema.from_orm(row) for row in rows] + + def _translate_one(self, row): + if row is None: + return + return self.pydantic_schema.from_orm(row) + + def all(self): + return self._translate_many(super().all()) + + def fetchall(self): + return self._translate_many(super().fetchall()) + + def fetchmany(self, *args, **kwargs): + return self._translate_many(super().fetchmany(*args, **kwargs)) + + def first(self): + return self._translate_one(super().first()) + + def one(self): + return self._translate_one(super().one()) + + def one_or_none(self): + return self._translate_one(super().one_or_none()) + + def partitions(self, *args, **kwargs): + for partition in super().partitions(*args, **kwargs): + yield self._translate_many(partition) + + +class PydanticRepository( + SQLAlchemyRepository[EntityT, EntityIdT, SessionT], + t.Generic[EntityT, EntityIdT, SessionT, ModelSchemaT, CreateSchemaT, UpdateSchemaT], +): + model_schema: t.Type[ModelSchemaT] + + def insert( + self, + session: SessionT, + create_schema: CreateSchemaT, + sqla_model=False, + ): + create_data = create_schema.dict() + result = super().insert(session, create_data) + + if sqla_model: + return result + return self.model_schema.from_orm(result) + + def update( + self, + session: SessionT, + id_: EntityIdT, + update_schema: UpdateSchemaT, + sqla_model=False, + ): + existing = session.query(self.model).get(id_) + if existing is None: + raise ValueError("Model not found") + + update_data = 
update_schema.dict(exclude_unset=True) + for key, value in update_data.items(): + setattr(existing, key, value) + + session.add(existing) + session.flush() + session.refresh(existing) + + if sqla_model: + return existing + return self.model_schema.from_orm(existing) + + def get( + self, + session: SessionT, + id_: EntityIdT, + options: t.Sequence[ORMOption] = (), + execution_options: t.Optional[t.Dict[str, t.Any]] = None, + for_update: bool = False, + include_inactive: bool = False, + sqla_model: bool = False, + ): + row = super().get( + session, + id_, + options, + execution_options, + for_update, + include_inactive, + ) + if row is None: + return + + if sqla_model: + return row + return self.model_schema.from_orm(row) + + def select( + self, + session: SessionT, + selectables: t.Sequence[Selectable] = (), + conditions: t.Sequence[ColumnExpr] = (), + group_by: t.Sequence[t.Union[ColumnExpr, str]] = (), + order_by: t.Sequence[t.Union[ColumnExpr, str]] = (), + options: t.Sequence[ORMOption] = (), + execution_options: t.Optional[t.Dict[str, t.Any]] = None, + offset: t.Optional[int] = None, + limit: t.Optional[int] = None, + distinct: bool = False, + for_update: bool = False, + include_inactive: bool = False, + yield_by_chunk: t.Optional[int] = None, + sqla_model: bool = False, + ): + result = super().select( + session, + selectables, + conditions, + group_by, + order_by, + options, + execution_options, + offset, + limit, + distinct, + for_update, + include_inactive, + yield_by_chunk, + ) + + if sqla_model: + return result + return PydanticScalarResult[self.model_schema](result, self.model_schema) diff --git a/src/quart_sqlalchemy/sim/schema.py b/src/quart_sqlalchemy/sim/schema.py new file mode 100644 index 0000000..e679d61 --- /dev/null +++ b/src/quart_sqlalchemy/sim/schema.py @@ -0,0 +1,99 @@ +import typing as t +from datetime import datetime +from enum import Enum + +from pydantic import BaseModel +from pydantic import Field +from pydantic import validator +from 
pydantic.generics import GenericModel + +from .model import ConnectInteropStatus +from .model import EntityType +from .model import Provenance +from .model import WalletManagementType +from .model import WalletType +from .util import ObjectID + + +DataT = t.TypeVar("DataT") + +json_encoders = { + ObjectID: lambda v: v.encode(), + datetime: lambda dt: int(dt.timestamp()), + Enum: lambda e: e.value, +} + + +class BaseSchema(BaseModel): + class Config: + arbitrary_types_allowed = True + json_encoders = dict(json_encoders) + orm_mode = True + + @classmethod + def _get_value(cls, v: t.Any, *args: t.Any, **kwargs: t.Any) -> t.Any: + if hasattr(v, "__serialize__"): + return v.__serialize__() + for type_, converter in cls.__config__.json_encoders.items(): + if isinstance(v, type_): + return converter(v) + + return super()._get_value(v, *args, **kwargs) + + +class ResponseWrapper(GenericModel, t.Generic[DataT]): + """Generic response wrapper""" + + class Config: + arbitrary_types_allowed = True + json_encoders = dict(json_encoders) + + error_code: str = "" + status: str = "" + message: str = "" + data: DataT = Field(default_factory=dict) + + @validator("status") + def set_status_by_error_code(cls, v, values): + error_code = values.get("error_code") + if error_code: + return "failed" + return "ok" + + +class MagicClientSchema(BaseSchema): + id: ObjectID + app_name: str + rate_limit_tier: t.Optional[str] = None + connect_interop: t.Optional[ConnectInteropStatus] = None + is_signing_modal_enabled: bool + global_audience_enabled: bool + public_api_key: str + secret_api_key: str + + +class AuthUserSchema(BaseSchema): + id: ObjectID + client_id: ObjectID + email: str + phone_number: t.Optional[str] = None + user_type: EntityType = EntityType.MAGIC + provenance: t.Optional[Provenance] = None + date_verified: t.Optional[datetime] = None + is_admin: bool = False + linked_primary_auth_user_id: t.Optional[ObjectID] = None + global_auth_user_id: t.Optional[ObjectID] = None + 
delegated_user_id: t.Optional[str] = None + delegated_identity_pool_id: t.Optional[str] = None + current_session_token: t.Optional[str] = None + + +class AuthWalletSchema(BaseSchema): + id: ObjectID + auth_user_id: ObjectID + wallet_type: WalletType + management_type: WalletManagementType + public_address: str + encrypted_private_address: str + network: str + is_exported: bool diff --git a/src/quart_sqlalchemy/sim/signals.py b/src/quart_sqlalchemy/sim/signals.py new file mode 100644 index 0000000..510be83 --- /dev/null +++ b/src/quart_sqlalchemy/sim/signals.py @@ -0,0 +1,34 @@ +from blinker import Namespace + + +# Synchronous signals +_sync = Namespace() + +auth_user_duplicate = _sync.signal( + "auth_user_duplicate", + doc="""Called on discovery of at least one duplicate auth user. + + Handlers should have the following signature: + def handler( + current_app: Quart, + original_auth_user_id: ObjectID, + duplicate_auth_user_ids: List[ObjectID], + session: sa.orm.Session, + ) -> None: + ... + """, +) + +keys_rolled = _sync.signal( + "keys_rolled", + doc="""Called after api keys are rolled. + + Handlers should have the following signature: + def handler( + app: Quart, + deactivated_keys: Dict[str, Any], + redis_client: Redis, + ) -> None: + ... 
+ """, +) diff --git a/src/quart_sqlalchemy/sim/testing.py b/src/quart_sqlalchemy/sim/testing.py new file mode 100644 index 0000000..fd84547 --- /dev/null +++ b/src/quart_sqlalchemy/sim/testing.py @@ -0,0 +1,16 @@ +from contextlib import contextmanager + +from quart import g +from quart import Quart +from quart import signals + +from quart_sqlalchemy import Bind + + +@contextmanager +def global_bind(app: Quart, bind: Bind): + def handler(sender, **kwargs): + g.bind = bind + + with signals.appcontext_pushed.connected_to(handler, app): + yield diff --git a/src/quart_sqlalchemy/sim/util.py b/src/quart_sqlalchemy/sim/util.py new file mode 100644 index 0000000..d4d33d1 --- /dev/null +++ b/src/quart_sqlalchemy/sim/util.py @@ -0,0 +1,100 @@ +import logging +import numbers + +from hashids import Hashids + + +logger = logging.getLogger(__name__) + + +class CryptographyError(Exception): + pass + + +class DecryptionError(CryptographyError): + pass + + +def one(input_list): + if len(input_list) != 1: + raise ValueError(f"Expected a list of length 1, got {len(input_list)}") + return input_list[0] + + +class ObjectID: + hashids = Hashids(min_length=12) + + def __init__(self, input_value): + if input_value is None: + raise ValueError("ObjectID cannot be None") + elif isinstance(input_value, ObjectID): + self._source_id = input_value._decoded_id + elif isinstance(input_value, str): + self._source_id = self._decode(input_value) + elif isinstance(input_value, numbers.Number): + try: + input_value = int(input_value) + except (ValueError, TypeError): + pass + + self._source_id = input_value + self._encode() + + @property + def _encoded_id(self): + return self._encode() + + @property + def _decoded_id(self): + return self._source_id + + def __eq__(self, other): + if isinstance(other, ObjectID): + return self._decoded_id == other._decoded_id and self._encoded_id == other._encoded_id + elif isinstance(other, int): + return self._decoded_id == other + elif isinstance(other, str): + return 
self._encoded_id == other + else: + return False + + def __lt__(self, other): + if isinstance(other, ObjectID): + return self._decoded_id < other._decoded_id + return False + + def __hash__(self): + return hash(tuple([self._encoded_id, self._decoded_id])) + + def __str__(self): + return "{encoded_id}".format(encoded_id=self._encoded_id) + + def __int__(self): + return self._decoded_id + + def __repr__(self): + return f"{type(self).__name__}({self._decoded_id})" + + def __json__(self): + return self.__str__() + + def _encode(self): + if isinstance(self._source_id, int): + return self.hashids.encode(self._source_id) + else: + return self._source_id + + def encode(self): + return self._encoded_id + + def _decode(self, value): + if isinstance(value, int): + return value + else: + return self.hashids.decode(value)[0] + + def decode(self): + return self._decoded_id + + def decode_str(self): + return str(self._decoded_id) diff --git a/src/quart_sqlalchemy/sim/views/__init__.py b/src/quart_sqlalchemy/sim/views/__init__.py new file mode 100644 index 0000000..656f265 --- /dev/null +++ b/src/quart_sqlalchemy/sim/views/__init__.py @@ -0,0 +1,18 @@ +from quart import Blueprint +from quart import g + +from .auth_user import api as auth_user_api +from .auth_wallet import api as auth_wallet_api +from .magic_client import api as magic_client_api + + +api = Blueprint("api", __name__, url_prefix="/api") + +api.register_blueprint(auth_user_api) +api.register_blueprint(auth_wallet_api) +api.register_blueprint(magic_client_api) + + +@api.before_request +def set_feature_owner(): + g.request_feature_owner = "auth-team" diff --git a/src/quart_sqlalchemy/sim/views/auth_user.py b/src/quart_sqlalchemy/sim/views/auth_user.py new file mode 100644 index 0000000..cf1d366 --- /dev/null +++ b/src/quart_sqlalchemy/sim/views/auth_user.py @@ -0,0 +1,86 @@ +import logging +import typing as t + +import sqlalchemy.orm +from dependency_injector.wiring import inject +from dependency_injector.wiring import 
Provide + +from quart_sqlalchemy.framework import QuartSQLAlchemy + +from ..auth import authorized_request +from ..auth import RequestCredentials +from ..container import Container +from ..handle import AuthUserHandler +from ..model import EntityType +from ..schema import AuthUserSchema +from ..schema import BaseSchema +from ..schema import ResponseWrapper +from .util import APIBlueprint + + +sa = sqlalchemy + +logger = logging.getLogger(__name__) +api = APIBlueprint("auth_user", __name__, url_prefix="/auth_user") + + +class CreateAuthUserRequest(BaseSchema): + email: str + + +class CreateAuthUserResponse(BaseSchema): + auth_user: AuthUserSchema + + +@api.get( + "/", + authorizer=authorized_request( + [ + { + "public-api-key": [], + "session-token-bearer": [], + } + ], + ), +) +@inject +def get_auth_user( + auth_user_handler: AuthUserHandler = Provide["AuthUserHandler"], + db: QuartSQLAlchemy = Provide["db"], + credentials: RequestCredentials = Provide["request_credentials"], +) -> ResponseWrapper[AuthUserSchema]: + with db.bind.Session() as session: + auth_user = auth_user_handler.get_by_session_token(session, credentials.current_user.value) + + return ResponseWrapper[AuthUserSchema](data=AuthUserSchema.from_orm(auth_user)) + + +@api.post( + "/", + authorizer=authorized_request( + [ + { + "public-api-key": [], + } + ], + ), +) +@inject +def create_auth_user( + data: CreateAuthUserRequest, + auth_user_handler: AuthUserHandler = Provide["AuthUserHandler"], + db: QuartSQLAlchemy = Provide["db"], + credentials: RequestCredentials = Provide[Container.request_credentials], +) -> ResponseWrapper[CreateAuthUserResponse]: + with db.bind.Session() as session: + with session.begin(): + client = auth_user_handler.create_verified_user( + session, + email=data.email, + client_id=credentials.current_client.subject.id, + user_type=EntityType.MAGIC.value, + ) + + return ResponseWrapper[CreateAuthUserResponse]( + data=dict(auth_user=AuthUserSchema.from_orm(client)) # type: ignore + 
) diff --git a/src/quart_sqlalchemy/sim/views/auth_wallet.py b/src/quart_sqlalchemy/sim/views/auth_wallet.py new file mode 100644 index 0000000..9275999 --- /dev/null +++ b/src/quart_sqlalchemy/sim/views/auth_wallet.py @@ -0,0 +1,90 @@ +import logging +import typing as t + +from dependency_injector.wiring import inject +from dependency_injector.wiring import Provide +from quart import g + +from quart_sqlalchemy.framework import QuartSQLAlchemy + +from ..auth import authorized_request +from ..auth import RequestCredentials +from ..handle import AuthWalletHandler +from ..model import WalletManagementType +from ..model import WalletType +from ..schema import BaseSchema +from ..schema import ResponseWrapper +from ..util import ObjectID +from ..web3 import Web3 +from .util import APIBlueprint + + +logger = logging.getLogger(__name__) +api = APIBlueprint("auth_wallet", __name__, url_prefix="/auth_wallet") + + +@api.before_request +def set_feature_owner(): + g.request_feature_owner = "wallet-team" + + +class WalletSyncRequest(BaseSchema): + public_address: str + encrypted_private_address: str + wallet_type: str + hd_path: t.Optional[str] = None + encrypted_seed_phrase: t.Optional[str] = None + + +class WalletSyncResponse(BaseSchema): + wallet_id: ObjectID + auth_user_id: ObjectID + wallet_type: WalletType + public_address: str + encrypted_private_address: str + + +@api.post( + "/sync", + authorizer=authorized_request( + [ + # We use the OpenAPI security scheme metadata to know which kind of authorization to enforce. 
+ # + # Together in the same dict implies logical AND requirement so both public-api-key and + # session-token will be enforced + { + "public-api-key": [], + "session-token-bearer": [], + } + ], + ), +) +@inject +def sync( + data: WalletSyncRequest, + auth_wallet_handler: AuthWalletHandler = Provide["AuthWalletHandler"], + web3: Web3 = Provide["web3"], + db: QuartSQLAlchemy = Provide["db"], + credentials: RequestCredentials = Provide["request_credentials"], +) -> ResponseWrapper[WalletSyncResponse]: + with db.bind.Session() as session: + with session.begin(): + wallet = auth_wallet_handler.sync_auth_wallet( + session, + credentials.current_user.subject.id, + data.public_address, + data.encrypted_private_address, + WalletManagementType.DELEGATED.value, + network=web3.network, + wallet_type=data.wallet_type, + ) + + return ResponseWrapper[WalletSyncResponse]( + data=dict( + wallet_id=wallet.id, + auth_user_id=wallet.auth_user_id, + wallet_type=wallet.wallet_type, + public_address=wallet.public_address, + encrypted_private_address=wallet.encrypted_private_address, + ) # type: ignore + ) diff --git a/src/quart_sqlalchemy/sim/views/magic_client.py b/src/quart_sqlalchemy/sim/views/magic_client.py new file mode 100644 index 0000000..dc506a2 --- /dev/null +++ b/src/quart_sqlalchemy/sim/views/magic_client.py @@ -0,0 +1,90 @@ +import logging +import typing as t + +from dependency_injector.wiring import inject +from dependency_injector.wiring import Provide + +from quart_sqlalchemy.framework import QuartSQLAlchemy + +from ..auth import authorized_request +from ..auth import RequestCredentials +from ..handle import MagicClientHandler +from ..model import ConnectInteropStatus +from ..schema import BaseSchema +from ..schema import MagicClientSchema +from ..schema import ResponseWrapper +from .util import APIBlueprint + + +logger = logging.getLogger(__name__) +api = APIBlueprint("magic_client", __name__, url_prefix="/magic_client") + + +class CreateMagicClientRequest(BaseSchema): 
+ app_name: str + rate_limit_tier: t.Optional[str] = None + connect_interop: t.Optional[ConnectInteropStatus] = None + is_signing_modal_enabled: bool = False + global_audience_enabled: bool = False + + +class CreateMagicClientResponse(BaseSchema): + magic_client: MagicClientSchema + + +@api.post( + "/", + authorizer=authorized_request( + [ + { + "public-api-key": [], + } + ], + ), +) +@inject +def create_magic_client( + data: CreateMagicClientRequest, + magic_client_handler: MagicClientHandler = Provide["MagicClientHandler"], + db: QuartSQLAlchemy = Provide["db"], +) -> ResponseWrapper[CreateMagicClientResponse]: + with db.bind.Session() as session: + with session.begin(): + client = magic_client_handler.add( + session, + app_name=data.app_name, + rate_limit_tier=data.rate_limit_tier, + connect_interop=data.connect_interop, + is_signing_modal_enabled=data.is_signing_modal_enabled, + global_audience_enabled=data.global_audience_enabled, + ) + + return ResponseWrapper[CreateMagicClientResponse]( + data=dict(magic_client=MagicClientSchema.from_orm(client)) # type: ignore + ) + + +@api.get( + "/", + authorizer=authorized_request( + [ + { + "public-api-key": [], + } + ], + ), +) +@inject +def get_magic_client( + magic_client_handler: MagicClientHandler = Provide["MagicClientHandler"], + credentials: RequestCredentials = Provide["request_credentials"], + db: QuartSQLAlchemy = Provide["db"], +) -> ResponseWrapper[MagicClientSchema]: + with db.bind.Session() as session: + client = magic_client_handler.get_by_public_api_key( + session, credentials.current_client.value + ) + + return ResponseWrapper[MagicClientSchema]( + data=MagicClientSchema.from_orm(client) # type: ignore + ) diff --git a/src/quart_sqlalchemy/sim/views/util/__init__.py b/src/quart_sqlalchemy/sim/views/util/__init__.py new file mode 100644 index 0000000..e677b31 --- /dev/null +++ b/src/quart_sqlalchemy/sim/views/util/__init__.py @@ -0,0 +1,12 @@ +from .blueprint import APIBlueprint +from .decorator import 
inject_request +from .decorator import validate_request +from .decorator import validate_response + + +__all__ = ( + "APIBlueprint", + "inject_request", + "validate_request", + "validate_response", +) diff --git a/src/quart_sqlalchemy/sim/views/util/blueprint.py b/src/quart_sqlalchemy/sim/views/util/blueprint.py new file mode 100644 index 0000000..35927ec --- /dev/null +++ b/src/quart_sqlalchemy/sim/views/util/blueprint.py @@ -0,0 +1,102 @@ +import inspect +import typing as t + +from quart import Blueprint +from quart import Request +from quart_schema.validation import validate_headers +from quart_schema.validation import validate_querystring + +from ...schema import BaseSchema +from .decorator import inject_request +from .decorator import validate_request +from .decorator import validate_response + + +class APIBlueprint(Blueprint): + def _endpoint( + self, + uri: str, + methods: t.Optional[t.Sequence[str]] = ("GET",), + authorizer: t.Optional[t.Callable] = None, + **route_kwargs, + ): + def decorator(func): + sig = inspect.signature(func) + + param_annotation_map = { + name: param.annotation for name, param in sig.parameters.items() + } + has_request_schema = "data" in sig.parameters and issubclass( + param_annotation_map["data"], + BaseSchema, + ) + has_query_schema = "query_args" in sig.parameters and issubclass( + param_annotation_map["query_args"], + BaseSchema, + ) + has_headers_schema = "headers" in sig.parameters and issubclass( + param_annotation_map["headers"], + BaseSchema, + ) + + has_response_schema = isinstance(sig.return_annotation, BaseSchema) + + should_inject_request, request_param_name = False, None + for name in param_annotation_map: + if isinstance(param_annotation_map[name], Request): + should_inject_request, request_param_name = True, name + break + + decorated = func + + if should_inject_request: + decorated = inject_request(request_param_name)(decorated) + + if has_query_schema: + decorated = 
validate_querystring(param_annotation_map["query_args"])(decorated) + + if has_headers_schema: + decorated = validate_headers(param_annotation_map["headers"])(decorated) + + if has_request_schema: + decorated = validate_request(param_annotation_map["data"])(decorated) + + if has_response_schema: + decorated = validate_response(sig.return_annotation)(decorated) + + if authorizer: + decorated = authorizer(decorated) + + return self.route(uri, t.cast(t.List[str], methods), **route_kwargs)(decorated) + + return decorator + + def get(self, *args, **kwargs): + if "methods" in kwargs: + del kwargs["methods"] + + return self._endpoint(*args, methods=["GET"], **kwargs) + + def post(self, *args, **kwargs): + if "methods" in kwargs: + del kwargs["methods"] + + return self._endpoint(*args, methods=["POST"], **kwargs) + + def put(self, *args, **kwargs): + if "methods" in kwargs: + del kwargs["methods"] + + return self._endpoint(*args, methods=["PUT"], **kwargs) + + def patch(self, *args, **kwargs): + if "methods" in kwargs: + del kwargs["methods"] + + return self._endpoint(*args, methods=["PATCH"], **kwargs) + + def delete(self, *args, **kwargs): + if "methods" in kwargs: + del kwargs["methods"] + + return self._endpoint(*args, methods=["DELETE"], **kwargs) diff --git a/src/quart_sqlalchemy/sim/views/util/decorator.py b/src/quart_sqlalchemy/sim/views/util/decorator.py new file mode 100644 index 0000000..55323d2 --- /dev/null +++ b/src/quart_sqlalchemy/sim/views/util/decorator.py @@ -0,0 +1,120 @@ +import typing as t +from dataclasses import asdict +from dataclasses import is_dataclass +from functools import wraps + +from humps import camelize +from humps import decamelize +from pydantic import BaseModel +from pydantic import ValidationError +from quart import current_app +from quart import request +from quart import Response +from quart_schema.typing import Model +from quart_schema.typing import ResponseReturnValue +from quart_schema.validation import 
QUART_SCHEMA_REQUEST_ATTRIBUTE +from quart_schema.validation import QUART_SCHEMA_RESPONSE_ATTRIBUTE +from quart_schema.validation import RequestSchemaValidationError +from quart_schema.validation import ResponseSchemaValidationError + + +def convert_model_result(func: t.Callable) -> t.Callable: + @wraps(func) + async def decorator(result: ResponseReturnValue) -> Response: + status_or_headers = None + headers = None + if isinstance(result, tuple): + value, status_or_headers, headers = result + (None,) * (3 - len(result)) + else: + value = result + + was_model = False + if is_dataclass(value): + dict_or_value = asdict(value) + was_model = True + elif isinstance(value, BaseModel): + dict_or_value = value.dict(by_alias=True) + was_model = True + else: + dict_or_value = value + + if was_model: + dict_or_value = camelize(dict_or_value) + + return await func((dict_or_value, status_or_headers, headers)) + + return decorator + + +def validate_request(model_class: Model) -> t.Callable: + def decorator(func: t.Callable) -> t.Callable: + setattr(func, QUART_SCHEMA_REQUEST_ATTRIBUTE, (model_class, None)) + + @wraps(func) + async def wrapper(*args, **kwargs): + data = await request.get_json() + data = decamelize(data) + + try: + model = model_class(**data) + except (TypeError, ValidationError) as error: + raise RequestSchemaValidationError(error) + else: + return await current_app.ensure_async(func)(*args, data=model, **kwargs) + + return wrapper + + return decorator + + +def validate_response(model_class: Model, status_code: int = 200) -> t.Callable: + def decorator(func): + schemas = getattr(func, QUART_SCHEMA_RESPONSE_ATTRIBUTE, {}) + schemas[status_code] = (model_class, None) + setattr(func, QUART_SCHEMA_RESPONSE_ATTRIBUTE, schemas) + + @wraps(func) + async def wrapper(*args, **kwargs): + result = await current_app.ensure_async(func)(*args, **kwargs) + + status_or_headers = None + headers = None + if isinstance(result, tuple): + value, status_or_headers, headers = result + 
(None,) * (3 - len(result)) + else: + value = result + + status = 200 + if isinstance(status_or_headers, int): + status = int(status_or_headers) + + if status == status_code: + try: + if isinstance(value, dict): + model_value = model_class(**value) + elif type(value) == model_class: + model_value = value + else: + raise ResponseSchemaValidationError() + except ValidationError as error: + raise ResponseSchemaValidationError(error) + + return model_value, status, headers + else: + return result + + return wrapper + + return decorator + + +def inject_request(key: str): + def decorator(func): + @wraps(func) + async def wrapper(*args, **kwargs): + kwargs[key] = request._get_current_object() + return await current_app.ensure_async(func)(*args, **kwargs) + + return wrapper + + return decorator diff --git a/src/quart_sqlalchemy/sim/web3.py b/src/quart_sqlalchemy/sim/web3.py new file mode 100644 index 0000000..e1a088b --- /dev/null +++ b/src/quart_sqlalchemy/sim/web3.py @@ -0,0 +1,153 @@ +import typing as t +from decimal import Decimal + +import typing_extensions as tx +import web3.providers +from ens import ENS +from eth_typing import AnyAddress +from eth_typing import ChecksumAddress +from eth_typing import HexStr +from eth_typing import Primitives +from eth_typing.abi import TypeStr +from quart import request +from quart.ctx import has_request_context +from web3.eth import Eth +from web3.geth import Geth +from web3.main import BaseWeb3 +from web3.module import Module +from web3.net import Net +from web3.providers import BaseProvider +from web3.types import Wei + + +""" +generate new key address pairing + +```zsh +python -c "from web3 import Web3; w3 = Web3(); acc = w3.eth.account.create(); print(f'private key={w3.to_hex(acc.key)}, account={acc.address}')" +``` +""" + + +class Web3Node(tx.Protocol): + eth: Eth + net: Net + geth: Geth + provider: BaseProvider + ens: ENS + + def is_connected(self) -> bool: + ... 
+ + @staticmethod + def to_bytes( + primitive: t.Optional[Primitives] = None, + hexstr: t.Optional[HexStr] = None, + text: t.Optional[str] = None, + ) -> bytes: + ... + + @staticmethod + def to_int( + primitive: t.Optional[Primitives] = None, + hexstr: t.Optional[HexStr] = None, + text: t.Optional[str] = None, + ) -> int: + ... + + @staticmethod + def to_hex( + primitive: t.Optional[Primitives] = None, + hexstr: t.Optional[HexStr] = None, + text: t.Optional[str] = None, + ) -> HexStr: + ... + + @staticmethod + def to_text( + primitive: t.Optional[Primitives] = None, + hexstr: t.Optional[HexStr] = None, + text: t.Optional[str] = None, + ) -> str: + ... + + @staticmethod + def to_json(obj: t.Dict[t.Any, t.Any]) -> str: + ... + + @staticmethod + def to_wei(number: t.Union[int, float, str, Decimal], unit: str) -> Wei: + ... + + @staticmethod + def from_wei(number: int, unit: str) -> t.Union[int, Decimal]: + ... + + @staticmethod + def is_address(value: t.Any) -> bool: + ... + + @staticmethod + def is_checksum_address(value: t.Any) -> bool: + ... + + @staticmethod + def to_checksum_address(value: t.Union[AnyAddress, str, bytes]) -> ChecksumAddress: + ... + + @property + def api(self) -> str: + ... + + @staticmethod + def keccak( + primitive: t.Optional[Primitives] = None, + text: t.Optional[str] = None, + hexstr: t.Optional[HexStr] = None, + ) -> bytes: + ... + + @classmethod + def normalize_values( + cls, _w3: BaseWeb3, abi_types: t.List[TypeStr], values: t.List[t.Any] + ) -> t.List[t.Any]: + ... + + @classmethod + def solidity_keccak(cls, abi_types: t.List[TypeStr], values: t.List[t.Any]) -> bytes: + ... + + def attach_modules( + self, modules: t.Optional[t.Dict[str, t.Union[t.Type[Module], t.Sequence[t.Any]]]] + ) -> None: + ... + + def is_encodable(self, _type: TypeStr, value: t.Any) -> bool: + ... 
+ + +def web3_node_factory(config): + if config["WEB3_PROVIDER_CLASS"] is web3.providers.HTTPProvider: + provider = config["WEB3_PROVIDER_CLASS"](config["WEB3_HTTPS_PROVIDER_URI"]) + return web3.Web3(provider) + + +class Web3: + node: Web3Node + + def __init__(self, node: Web3Node, default_network: str, default_chain: str): + self.node = node + self.default_network = default_network + self.default_chain = default_chain + + @property + def chain(self) -> str: + if has_request_context(): + return request.headers.get("x-web3-chain", self.default_chain).upper() + return self.default_chain + + @property + def network(self) -> str: + if has_request_context(): + return request.headers.get("x-web3-network", self.default_network).upper() + return self.default_network diff --git a/src/quart_sqlalchemy/sqla.py b/src/quart_sqlalchemy/sqla.py index 1df492a..b437625 100644 --- a/src/quart_sqlalchemy/sqla.py +++ b/src/quart_sqlalchemy/sqla.py @@ -10,47 +10,280 @@ import sqlalchemy.orm import sqlalchemy.util -from .bind import AsyncBind -from .bind import Bind -from .config import AsyncBindConfig -from .config import SQLAlchemyConfig +from quart_sqlalchemy.bind import AsyncBind +from quart_sqlalchemy.bind import Bind +from quart_sqlalchemy.config import AsyncBindConfig +from quart_sqlalchemy.config import SQLAlchemyConfig sa = sqlalchemy class SQLAlchemy: + """ + This manager class keeps things very simple by using a few configuration conventions. + + Configuration has been simplified down to base_class and binds. + + * Everything related to ORM mapping, DeclarativeBase, registry, MetaData, etc should be + configured by passing the a custom DeclarativeBase class as the base_class configuration + parameter. + + * Everything related to engine/session configuration should be configured by passing a + dictionary mapping string names to BindConfigs as the `binds` configuration parameter. 
+ + BindConfig can be as simple as a dictionary containing a url key like so: + + bind_config = { + "default": {"url": "sqlite://"} + } + + But most use cases will require more than just a connection url, and divide core/engine + configuration from orm/session configuration which looks more like this: + + bind_config = { + "default": { + "engine": { + "url": "sqlite://" + }, + "session": { + "expire_on_commit": False + } + } + } + + Everything under `engine` will then be passed to `sqlalchemy.create_engine_from_config` and + everything under `session` will be passed to `sqlalchemy.orm.sessionmaker`. + + engine = sa.create_engine_from_config(bind_config.engine) + Session = sa.orm.sessionmaker(bind=engine, **bind_config.session) + + Config Examples: + + Simple URL: + db = SQLAlchemy( + SQLAlchemyConfig( + binds=dict( + default=dict( + url="sqlite://" + ) + ) + ) + ) + + Shortcut for the above: + db = SQLAlchemy(SQLAlchemyConfig()) + + More complex configuration for engine and session both: + db = SQLAlchemy( + SQLAlchemyConfig( + binds=dict( + default=dict( + engine=dict( + url="sqlite://" + ), + session=dict( + expire_on_commit=False + ) + ) + ) + ) + ) + + Once instantiated, operations targetting all of the binds, aka metadata, like + `metadata.create_all` should be called from this class. Operations specific to a bind + should be called from that bind. This class has a few ways to get a specific bind. + + * To get a Bind, you can call `.get_bind(name)` on this class. The default bind can be + referenced at `.bind`. 
+ + * To define an ORM model using the Base class attached to this class, simply inherit + from `.Base` + + db = SQLAlchemy(SQLAlchemyConfig()) + + class User(db.Base): + id: Mapped[int] = sa.orm.mapped_column(primary_key=True, autoincrement=True) + name: Mapped[str] = sa.orm.mapped_column(default="Joe") + + db.create_all() + + * You can also decouple Base from SQLAlchemy with some dependency inversion: + from quart_sqlalchemy.model.mixins import DynamicArgsMixin, ReprMixin, TableNameMixin + + class Base(DynamicArgsMixin, ReprMixin, TableNameMixin): + __abstract__ = True + + + class User(db.Base): + id: Mapped[int] = sa.orm.mapped_column(primary_key=True, autoincrement=True) + name: Mapped[str] = sa.orm.mapped_column(default="Joe") + + db = SQLAlchemy(SQLAlchemyConfig(bind_class=Base)) + + db.create_all() + + + Declarative Mapping using registry based decorator: + + db = SQLAlchemy(SQLAlchemyConfig()) + + @db.registry.mapped + class User(db.Base): + id: Mapped[int] = sa.orm.mapped_column(primary_key=True, autoincrement=True) + name: Mapped[str] = sa.orm.mapped_column(default="Joe") + + db.create_all() + + + Declarative with Imperative Table (Hybrid Declarative): + + class User(db.Base): + __table__ = sa.Table( + "user", + db.metadata, + sa.Column("id", sa.Integer, primary_key=True, autoincrement=True), + sa.Column("name", sa.String, default="Joe"), + ) + + + Declarative using reflection to automatically build the table object: + + class User(db.Base): + __table__ = sa.Table( + "user", + db.metadata, + autoload_with=db.bind.engine, + ) + + + Declarative Dataclass Mapping: + + from quart_sqlalchemy.model import Base as Base_ + + class Base(sa.orm.MappedAsDataclass, Base_): + pass + + db = SQLAlchemy(SQLAlchemyConfig(base_class=Base)) + + class User(db.Base): + __tablename__ = "user" + + id: Mapped[int] = sa.orm.mapped_column(primary_key=True, autoincrement=True) + name: Mapped[str] = sa.orm.mapped_column(default="Joe") + + db.create_all() + + + Declarative Dataclass 
Mapping (using decorator): + + db = SQLAlchemy(SQLAlchemyConfig(base_class=Base)) + + @db.registry.mapped_as_dataclass + class User: + __tablename__ = "user" + + id: Mapped[int] = sa.orm.mapped_column(primary_key=True, autoincrement=True) + name: Mapped[str] = sa.orm.mapped_column(default="Joe") + + db.create_all() + + + Alternate Dataclass Provider Pattern: + + from pydantic.dataclasses import dataclass + from quart_sqlalchemy.model import Base as Base_ + + class Base(sa.orm.MappedAsDataclass, Base_, dataclass_callable=dataclass): + pass + + db = SQLAlchemy(SQLAlchemyConfig(base_class=Base)) + + class User(db.Base): + __tablename__ = "user" + + id: Mapped[int] = sa.orm.mapped_column(primary_key=True, autoincrement=True) + name: Mapped[str] = sa.orm.mapped_column(default="Joe") + + db.create_all() + + Imperative style Mapping + + db = SQLAlchemy(SQLAlchemyConfig(base_class=Base)) + + user_table = sa.Table( + "user", + db.metadata, + sa.Column("id", sa.Integer, primary_key=True, autoincrement=True), + sa.Column("name", sa.String, default="Joe"), + ) + + post_table = sa.Table( + "post", + db.metadata, + sa.Column("id", sa.Integer, primary_key=True, autoincrement=True), + sa.Column("title", sa.String, default="My post"), + sa.Column("user_id", sa.ForeignKey("user.id"), nullable=False), + ) + + class User: + pass + + class Post: + pass + + db.registry.map_imperatively( + User, + user_table, + properties={ + "posts": sa.orm.relationship(Post, back_populates="user") + } + ) + db.registry.map_imperatively( + Post, + post_table, + properties={ + "user": sa.orm.relationship(User, back_populates="posts", uselist=False) + } + ) + """ + config: SQLAlchemyConfig binds: t.Dict[str, t.Union[Bind, AsyncBind]] - Model: t.Type[sa.orm.DeclarativeBase] + Base: t.Type[sa.orm.DeclarativeBase] - def __init__(self, config: SQLAlchemyConfig, initialize: bool = True): + def __init__( + self, + config: t.Optional[SQLAlchemyConfig] = None, + initialize: bool = True, + ): self.config = config 
if initialize: self.initialize() - def initialize(self): - if issubclass(self.config.model_class, sa.orm.DeclarativeBase): - Model = self.config.model_class # type: ignore - else: + def initialize(self, config: t.Optional[SQLAlchemyConfig] = None): + if config is not None: + self.config = config + if self.config is None: + self.config = SQLAlchemyConfig.default() - class Model(self.config.model_class, sa.orm.DeclarativeBase): - pass + if issubclass(self.config.base_class, sa.orm.DeclarativeBase): + Base = self.config.base_class # type: ignore + else: + Base = type("Base", (self.config.base_class, sa.orm.DeclarativeBase), {}) - self.Model = Model + self.Base = Base - self.binds = {} - for name, bind_config in self.config.binds.items(): - is_async = isinstance(bind_config, AsyncBindConfig) - if is_async: - self.binds[name] = AsyncBind(bind_config, self.metadata) - else: - self.binds[name] = Bind(bind_config, self.metadata) + if not hasattr(self, "binds"): + self.binds = {} + for name, bind_config in self.config.binds.items(): + is_async = isinstance(bind_config, AsyncBindConfig) + factory = AsyncBind if is_async else Bind + self.binds[name] = factory(name, bind_config.engine.url, bind_config, self.metadata) - @classmethod - def default(cls): - return cls(SQLAlchemyConfig()) + def get_bind(self, bind: str = "default"): + return self.binds[bind] @property def bind(self) -> Bind: @@ -58,10 +291,11 @@ def bind(self) -> Bind: @property def metadata(self) -> sa.MetaData: - return self.Model.metadata + return self.Base.metadata - def get_bind(self, bind: str = "default"): - return self.binds[bind] + @property + def registry(self) -> sa.orm.registry: + return self.Base.registry def create_all(self, bind: str = "default"): return self.binds[bind].create_all() diff --git a/src/quart_sqlalchemy/testing/fake.py b/src/quart_sqlalchemy/testing/fake.py new file mode 100644 index 0000000..e69de29 diff --git a/src/quart_sqlalchemy/testing/signals.py 
b/src/quart_sqlalchemy/testing/signals.py new file mode 100644 index 0000000..d9c1cb9 --- /dev/null +++ b/src/quart_sqlalchemy/testing/signals.py @@ -0,0 +1,40 @@ +from __future__ import annotations + +import sqlalchemy +import sqlalchemy.orm +from blinker import Namespace +from quart.signals import AsyncNamespace + + +sa = sqlalchemy + +sync_signals = Namespace() +async_signals = AsyncNamespace() + + +load_test_fixtures = sync_signals.signal( + "quart-sqlalchemy.testing.fixtures.load.sync", + doc="""Fired to load test fixtures into a freshly instantiated test database. + + No default signal handlers exist for this signal as the logic is very application dependent. + + Example: + + @signals.framework_extension_load_fixtures.connect + def handle(sender: QuartSQLAlchemy, app: Quart): + bind = sender.get_bind("default") + with bind.Session() as session: + with session.begin(): + session.add_all( + [ + models.User(username="user1"), + models.User(username="user2"), + ] + ) + session.commit() + + Handler signature: + def handle(sender: QuartSQLAlchemy, app: Quart): + ... 
+ """, +) diff --git a/src/quart_sqlalchemy/testing/transaction.py b/src/quart_sqlalchemy/testing/transaction.py index 507a963..d7ff873 100644 --- a/src/quart_sqlalchemy/testing/transaction.py +++ b/src/quart_sqlalchemy/testing/transaction.py @@ -23,13 +23,13 @@ def __init__(self, bind: "Bind", savepoint: bool = False): self.savepoint = savepoint self.bind = bind - def Session(self, **options): + def Session(self, **options: t.Any) -> sa.orm.Session: options.update(bind=self.connection) if self.savepoint: options.update(join_transaction_mode="create_savepoint") return self.bind.Session(**options) - def begin(self): + def open(self) -> None: self.connection = self.bind.engine.connect() self.trans = self.connection.begin() @@ -59,7 +59,7 @@ def close(self, exc: t.Optional[Exception] = None) -> None: ) def __enter__(self): - self.begin() + self.open() return self def __exit__(self, exc_type, exc_val, exc_tb): @@ -82,7 +82,7 @@ class AsyncTestTransaction(TestTransaction): def __init__(self, bind: "AsyncBind", savepoint: bool = False): super().__init__(bind, savepoint=savepoint) - async def begin(self): + async def open(self): self.connection = await self.bind.engine.connect() self.trans = await self.connection.begin() @@ -112,7 +112,7 @@ async def close(self, exc: t.Optional[Exception] = None) -> None: ) async def __aenter__(self): - await self.begin() + await self.open() return self async def __aexit__(self, exc_type, exc_val, exc_tb): diff --git a/src/quart_sqlalchemy/types.py b/src/quart_sqlalchemy/types.py index 2c5b96f..73e6c6f 100644 --- a/src/quart_sqlalchemy/types.py +++ b/src/quart_sqlalchemy/types.py @@ -7,6 +7,7 @@ import sqlalchemy.orm import sqlalchemy.sql import typing_extensions as tx +from sqlalchemy import SQLColumnExpression from sqlalchemy.orm.interfaces import ORMOption as _ORMOption from sqlalchemy.sql._typing import _ColumnExpressionArgument from sqlalchemy.sql._typing import _ColumnsClauseArgument @@ -15,11 +16,18 @@ sa = sqlalchemy + +class 
Empty: + pass + + +EmptyType = t.Type[Empty] + SessionT = t.TypeVar("SessionT", bound=sa.orm.Session) EntityT = t.TypeVar("EntityT", bound=sa.orm.DeclarativeBase) EntityIdT = t.TypeVar("EntityIdT", bound=t.Any) -ColumnExpr = _ColumnExpressionArgument +ColumnExpr = SQLColumnExpression Selectable = _ColumnsClauseArgument DMLTable = _DMLTableArgument ORMOption = _ORMOption @@ -40,3 +48,8 @@ SABind = t.Union[ sa.Engine, sa.Connection, sa.ext.asyncio.AsyncEngine, sa.ext.asyncio.AsyncConnection ] + + +class Operator(tx.Protocol): + def __call__(self, __a: object, __b: object) -> object: + ... diff --git a/tests/base.py b/tests/base.py index 1ff8756..8afa3fb 100644 --- a/tests/base.py +++ b/tests/base.py @@ -2,6 +2,7 @@ import random import typing as t +from copy import deepcopy from datetime import datetime import pytest @@ -10,8 +11,12 @@ from quart import Quart from sqlalchemy.orm import Mapped -from quart_sqlalchemy import SQLAlchemyConfig +from quart_sqlalchemy import Base from quart_sqlalchemy.framework import QuartSQLAlchemy +from quart_sqlalchemy.model.mixins import ComparableMixin +from quart_sqlalchemy.model.mixins import DynamicArgsMixin +from quart_sqlalchemy.model.mixins import EagerDefaultsMixin +from quart_sqlalchemy.model.mixins import TableNameMixin from . 
import constants @@ -21,27 +26,156 @@ class SimpleTestBase: @pytest.fixture(scope="class") - def app(self, request): + def Base(self) -> t.Type[t.Any]: + return Base + + @pytest.fixture(scope="class") + def app_config(self, Base): + config = deepcopy(constants.simple_config) + config.update(SQLALCHEMY_BASE_CLASS=Base) + return config + + @pytest.fixture(scope="class") + def app(self, app_config, request): app = Quart(request.module.__name__) - app.config.from_mapping({"TESTING": True}) + app.config.from_mapping(app_config) + app.config["TESTING"] = True return app @pytest.fixture(scope="class") - def sqlalchemy_config(self): - return SQLAlchemyConfig.parse_obj(constants.simple_mapping_config) + def db(self, app: Quart) -> t.Generator[QuartSQLAlchemy, None, None]: + db = QuartSQLAlchemy(app=app) - @pytest.fixture(scope="class") - def db(self, sqlalchemy_config, app: Quart) -> QuartSQLAlchemy: - return QuartSQLAlchemy(sqlalchemy_config, app) - # yield db - # db.drop_all() + yield db @pytest.fixture(scope="class") - def models(self, app: Quart, db: QuartSQLAlchemy) -> t.Mapping[str, t.Type[t.Any]]: - class Todo(db.Model): + def models( + self, app: Quart, db: QuartSQLAlchemy + ) -> t.Generator[t.Mapping[str, t.Type[t.Any]], None, None]: + class Todo(db.Base): + id: Mapped[int] = sa.orm.mapped_column(primary_key=True, autoincrement=True) + title: Mapped[str] = sa.orm.mapped_column(default="default") + user_id: Mapped[t.Optional[int]] = sa.orm.mapped_column(sa.ForeignKey("user.id")) + + user: Mapped[t.Optional["User"]] = sa.orm.relationship( + back_populates="todos", lazy="noload", uselist=False + ) + + class User(db.Base): id: Mapped[int] = sa.orm.mapped_column( - sa.Identity(), primary_key=True, autoincrement=True + primary_key=True, + autoincrement=True, ) + name: Mapped[str] = sa.orm.mapped_column(default="default") + + created_at: Mapped[datetime] = sa.orm.mapped_column( + default=sa.func.now(), + server_default=sa.FetchedValue(), + ) + + time_updated: 
Mapped[datetime] = sa.orm.mapped_column( + default=sa.func.now(), + onupdate=sa.func.now(), + server_default=sa.FetchedValue(), + server_onupdate=sa.FetchedValue(), + ) + + todos: Mapped[t.List[Todo]] = sa.orm.relationship(lazy="noload", back_populates="user") + + yield dict(todo=Todo, user=User) + # We need to cleanup these objects that like to retain state beyond the fixture scope lifecycle + Base.registry.dispose() + Base.metadata.clear() + + @pytest.fixture(scope="class", autouse=True) + def create_drop_all(self, db: QuartSQLAlchemy, models): + db.create_all() + yield + db.drop_all() + + @pytest.fixture(scope="class") + def Todo(self, models: t.Mapping[str, t.Type[t.Any]]) -> t.Type[sa.orm.DeclarativeBase]: + return models["todo"] + + @pytest.fixture(scope="class") + def User(self, models: t.Mapping[str, t.Type[t.Any]]) -> t.Type[sa.orm.DeclarativeBase]: + return models["user"] + + @pytest.fixture(scope="class") + def _user_fixtures(self, User: t.Type[t.Any], Todo: t.Type[t.Any]): + users = [] + for i in range(5): + user = User(name=f"user: {i}") + for j in range(random.randint(0, 6)): + todo = Todo(title=f"todo: {j}") + user.todos.append(todo) + users.append(user) + return users + + @pytest.fixture(scope="class") + def _add_fixtures( + self, db: QuartSQLAlchemy, User: t.Type[t.Any], Todo: t.Type[t.Any], _user_fixtures + ) -> None: + with db.bind.Session() as s: + with s.begin(): + s.add_all(_user_fixtures) + + @pytest.fixture(scope="class", autouse=True) + def db_fixtures( + self, db: QuartSQLAlchemy, User: t.Type[t.Any], Todo: t.Type[t.Any], _add_fixtures + ) -> t.Dict[t.Type[t.Any], t.Sequence[t.Any]]: + with db.bind.Session() as s: + users = s.scalars(sa.select(User).options(sa.orm.selectinload(User.todos))).all() + todos = s.scalars(sa.select(Todo).options(sa.orm.selectinload(Todo.user))).all() + + return {User: users, Todo: todos} + + +class MixinTestBase: + default_mixins = ( + DynamicArgsMixin, + EagerDefaultsMixin, + TableNameMixin, + ComparableMixin, 
+ ) + extra_mixins = () + + @pytest.fixture(scope="class") + def Base(self) -> t.Type[t.Any]: + return type( + "Base", + tuple(self.extra_mixins + self.default_mixins), + {"__abstract__": True}, + ) + + @pytest.fixture(scope="class") + def app_config(self, Base): + config = deepcopy(constants.simple_config) + config.update(SQLALCHEMY_BASE_CLASS=Base) + return config + + @pytest.fixture(scope="class") + def app(self, app_config, request): + app = Quart(request.module.__name__) + app.config.from_mapping(app_config) + app.config["TESTING"] = True + return app + + @pytest.fixture(scope="class") + def db(self, app: Quart) -> t.Generator[QuartSQLAlchemy, None, None]: + db = QuartSQLAlchemy(app=app) + + yield db + + # It's very important to clear the class _instances dict before recreating binds with the same name. + # Bind._instances.clear() + + @pytest.fixture(scope="class") + def models( + self, app: Quart, db: QuartSQLAlchemy + ) -> t.Generator[t.Mapping[str, t.Type[t.Any]], None, None]: + class Todo(db.Base): + id: Mapped[int] = sa.orm.mapped_column(primary_key=True, autoincrement=True) title: Mapped[str] = sa.orm.mapped_column(default="default") user_id: Mapped[t.Optional[int]] = sa.orm.mapped_column(sa.ForeignKey("user.id")) @@ -49,9 +183,8 @@ class Todo(db.Model): back_populates="todos", lazy="noload", uselist=False ) - class User(db.Model): + class User(db.Base): id: Mapped[int] = sa.orm.mapped_column( - sa.Identity(), primary_key=True, autoincrement=True, ) @@ -71,7 +204,10 @@ class User(db.Model): todos: Mapped[t.List[Todo]] = sa.orm.relationship(lazy="noload", back_populates="user") - return dict(todo=Todo, user=User) + yield dict(todo=Todo, user=User) + # We need to cleanup these objects that like to retain state beyond the fixture scope lifecycle + Base.registry.dispose() + Base.metadata.clear() @pytest.fixture(scope="class", autouse=True) def create_drop_all(self, db: QuartSQLAlchemy, models): @@ -119,8 +255,10 @@ def db_fixtures( class 
AsyncTestBase(SimpleTestBase): @pytest.fixture(scope="class") - def sqlalchemy_config(self): - return SQLAlchemyConfig.parse_obj(constants.async_mapping_config) + def app_config(self, Base): + config = deepcopy(constants.async_config) + config.update(SQLALCHEMY_BASE_CLASS=Base) + return config @pytest.fixture(scope="class", autouse=True) async def create_drop_all(self, db: QuartSQLAlchemy, models) -> t.AsyncGenerator[None, None]: @@ -151,5 +289,31 @@ async def db_fixtures( class ComplexTestBase(SimpleTestBase): @pytest.fixture(scope="class") - def sqlalchemy_config(self): - return SQLAlchemyConfig.parse_obj(constants.complex_mapping_config) + def app_config(self, Base): + config = deepcopy(constants.complex_config) + config.update(SQLALCHEMY_BASE_CLASS=Base) + return config + + +# class CustomMixinTestBase(SimpleTestBase): +# default_mixins = ( +# DynamicArgsMixin, +# EagerDefaultsMixin, +# TableNameMixin, +# ) +# additional_mixins = () + +# @pytest.fixture(scope="class") +# def Base(self) -> t.Type[t.Any]: +# return type( +# "Base", +# tuple(self.additional_mixins + self.default_mixins), +# {"__abstract__": True}, +# ) + +# @pytest.fixture(scope="class") +# def app_config(self, Base): +# config = deepcopy(constants.simple_config) +# config["SQLALCHEMY_BASE_CLASS"] = Base +# config["TESTING"] = True +# return config diff --git a/tests/conftest.py b/tests/conftest.py index eefd7d7..e69de29 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,49 +0,0 @@ -from __future__ import annotations - -import typing as t - -import pytest -import sqlalchemy -import sqlalchemy.orm -from quart import Quart -from sqlalchemy.orm import Mapped - -from quart_sqlalchemy import SQLAlchemyConfig -from quart_sqlalchemy.framework import QuartSQLAlchemy - -from . 
import constants - - -sa = sqlalchemy - - -@pytest.fixture(scope="session") -def app(request: pytest.FixtureRequest) -> Quart: - app = Quart(request.module.__name__) - app.config.from_mapping({"TESTING": True}) - return app - - -@pytest.fixture(scope="session") -def sqlalchemy_config(): - return SQLAlchemyConfig.parse_obj(constants.simple_mapping_config) - - -@pytest.fixture(scope="session") -def db(sqlalchemy_config, app: Quart) -> QuartSQLAlchemy: - return QuartSQLAlchemy(sqlalchemy_config, app) - - -@pytest.fixture(name="Todo", scope="session") -def _todo_fixture( - app: Quart, db: QuartSQLAlchemy -) -> t.Generator[t.Type[sa.orm.DeclarativeBase], None, None]: - class Todo(db.Model): - id: Mapped[int] = sa.orm.mapped_column(sa.Identity(), primary_key=True, autoincrement=True) - title: Mapped[str] = sa.orm.mapped_column(default="default") - - db.create_all() - - yield Todo - - db.drop_all() diff --git a/tests/constants.py b/tests/constants.py index 4fb8240..6a2277c 100644 --- a/tests/constants.py +++ b/tests/constants.py @@ -1,19 +1,18 @@ from quart_sqlalchemy import Base -simple_mapping_config = { - "model_class": Base, - "binds": { +simple_config = { + "SQLALCHEMY_BINDS": { "default": { "engine": {"url": "sqlite:///file:mem.db?mode=memory&cache=shared&uri=true"}, "session": {"expire_on_commit": False}, } }, + "SQLALCHEMY_BASE_CLASS": Base, } -complex_mapping_config = { - "model_class": Base, - "binds": { +complex_config = { + "SQLALCHEMY_BINDS": { "default": { "engine": {"url": "sqlite:///file:mem.db?mode=memory&cache=shared&uri=true"}, "session": {"expire_on_commit": False}, @@ -28,14 +27,15 @@ "session": {"expire_on_commit": False}, }, }, + "SQLALCHEMY_BASE_CLASS": Base, } -async_mapping_config = { - "model_class": Base, - "binds": { +async_config = { + "SQLALCHEMY_BINDS": { "default": { "engine": {"url": "sqlite+aiosqlite:///file:mem.db?mode=memory&cache=shared&uri=true"}, "session": {"expire_on_commit": False}, - } + }, }, + "SQLALCHEMY_BASE_CLASS": Base, } 
diff --git a/tests/integration/concurrency/__init__.py b/tests/integration/concurrency/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/integration/concurrency/with_for_update.py b/tests/integration/concurrency/with_for_update.py new file mode 100644 index 0000000..a9fa552 --- /dev/null +++ b/tests/integration/concurrency/with_for_update.py @@ -0,0 +1,103 @@ +import logging +import threading +import time + +import pytest +import sqlalchemy +import sqlalchemy.orm + + +sa = sqlalchemy + +logging.basicConfig(level=logging.DEBUG) +logging.getLogger("sqlalchemy").setLevel(logging.INFO) + +log = logging.getLogger(__name__) + + +class Base(sa.orm.DeclarativeBase): + pass + + +class Thing(Base): + __tablename__ = "things" + + id = sa.Column(sa.Integer, primary_key=True) + status = sa.Column(sa.String) + + +@pytest.fixture(scope="module") +def engine(): + engine = sa.create_engine("sqlite:///") + # engine = sa.create_engine("postgresql+psycopg2://spikes:sesame@localhost/spikes") + Base.metadata.create_all(engine) + + yield engine + + Base.metadata.drop_all(engine) + + +@pytest.fixture(scope="module") +def connection(engine): + with engine.connect() as conn: + yield conn + + +@pytest.fixture +def db(connection): + transaction = connection.begin() + session = sa.orm.Session(bind=connection) + + # now we can even `.commit()` such session + yield session + + session.close() + transaction.rollback() + + +def test_select_for_update(engine): + # scoped_db = scoped_session(sessionmaker(bind=connection)) + scoped_db = sa.orm.scoped_session(sa.orm.sessionmaker(bind=engine)) + db = scoped_db() + db.add(Thing(status="old")) + db.commit() + + def first(event, sess_factory, status): + sess = sess_factory() + # thing = sess.query(Thing).get(1) + thing = sess.query(Thing).with_for_update().get(1) + event.set() # poke second thread + log.debug("Make him wait for a while") + time.sleep(0.263) + thing.status = status + sess.commit() + log.debug("Done!") + # it is 
always better to explicitly `.remove()` scoped sessions, but + # in this case it is not necessary because it will be garbage-collected + # sess_factory.remove() + + def second(event, sess_factory, status): + event.wait() # ensure we are called in the right moment + sess = sess_factory() + # thing = sess.query(Thing).get(1) + thing = sess.query(Thing).with_for_update().get(1) + thing.status = status + sess.commit() + + event = threading.Event() + th1 = threading.Thread(target=first, args=(event, scoped_db, "new")) + th2 = threading.Thread(target=second, args=(event, scoped_db, "brand_new")) + + th1.start() + th2.start() + + th1.join() + th2.join() + + # assert db.query(Thing).filter_by(id=1).one().status == 'new' + t = db.query(Thing).get(1) + # it is only mandatory to remove session here, seems like it is not + # garbage-collected becasue it is in `assert` statement (not sure about that) + scoped_db.remove() + + assert t.status == "brand_new" diff --git a/tests/integration/framework/smoke_test.py b/tests/integration/framework/smoke_test.py index 682d0e4..7154fd7 100644 --- a/tests/integration/framework/smoke_test.py +++ b/tests/integration/framework/smoke_test.py @@ -40,7 +40,7 @@ def test_simple_transactional_orm_flow(self, db: QuartSQLAlchemy, Todo: t.Any): def test_simple_transactional_core_flow(self, db: QuartSQLAlchemy, Todo: t.Any): with db.bind.engine.connect() as conn: with conn.begin(): - result = conn.execute(sa.insert(Todo)) + result = conn.execute(sa.insert(Todo).values(title="default")) insert_row = result.inserted_primary_key select_row = conn.execute(sa.select(Todo).where(Todo.id == insert_row.id)).one() @@ -56,15 +56,3 @@ def test_simple_transactional_core_flow(self, db: QuartSQLAlchemy, Todo: t.Any): with db.bind.engine.connect() as conn: with pytest.raises(sa.exc.NoResultFound): conn.execute(sa.select(Todo).where(Todo.id == insert_row.id)).one() - - def test_orm_models_comparable(self, db: QuartSQLAlchemy, Todo: t.Any): - with db.bind.Session() as 
s: - with s.begin(): - todo = Todo() - s.add(todo) - s.flush() - s.refresh(todo) - - with db.bind.Session() as s: - select_todo = s.scalars(sa.select(Todo).where(Todo.id == todo.id)).one() - assert todo == select_todo diff --git a/tests/integration/model/mixins_test.py b/tests/integration/model/mixins_test.py index a88b4d9..43928e6 100644 --- a/tests/integration/model/mixins_test.py +++ b/tests/integration/model/mixins_test.py @@ -8,21 +8,26 @@ from sqlalchemy.orm import Mapped from quart_sqlalchemy import SQLAlchemy -from quart_sqlalchemy.model import Base -from quart_sqlalchemy.model import SoftDeleteMixin +from quart_sqlalchemy.framework import QuartSQLAlchemy +from quart_sqlalchemy.model.mixins import ComparableMixin +from quart_sqlalchemy.model.mixins import RecursiveDictMixin +from quart_sqlalchemy.model.mixins import ReprMixin +from quart_sqlalchemy.model.mixins import SimpleDictMixin +from quart_sqlalchemy.model.mixins import SoftDeleteMixin +from quart_sqlalchemy.model.mixins import TotalOrderMixin -from ...base import SimpleTestBase +from ... 
import base sa = sqlalchemy -class TestSoftDeleteFeature(SimpleTestBase): - @pytest.fixture - def Post(self, db: SQLAlchemy, User: t.Type[t.Any]) -> t.Generator[t.Type[Base], None, None]: - class Post(SoftDeleteMixin, db.Model): - id: Mapped[int] = sa.orm.mapped_column(primary_key=True) - title: Mapped[str] = sa.orm.mapped_column() +class TestSoftDeleteFeature(base.MixinTestBase): + @pytest.fixture(scope="class") + def Post(self, db: SQLAlchemy, User: t.Type[t.Any]) -> t.Generator[t.Type[t.Any], None, None]: + class Post(SoftDeleteMixin, db.Base): + id: Mapped[int] = sa.orm.mapped_column(primary_key=True, autoincrement=True) + title: Mapped[str] = sa.orm.mapped_column(default="default") user_id: Mapped[t.Optional[int]] = sa.orm.mapped_column(sa.ForeignKey("user.id")) user: Mapped[t.Optional[User]] = sa.orm.relationship(backref="posts") @@ -53,3 +58,113 @@ def test_inactive_filtered(self, db: SQLAlchemy, Post: t.Type[t.Any]): assert select_post.id == post.id assert select_post.is_active is False + + +class TestComparableMixin(base.MixinTestBase): + extra_mixins = (TotalOrderMixin,) + + def test_orm_models_comparable(self, db: QuartSQLAlchemy, Todo: t.Any): + assert ComparableMixin in self.default_mixins + + with db.bind.Session() as s: + with s.begin(): + todos = [Todo() for _ in range(5)] + s.add_all(todos) + + with db.bind.Session() as s: + todos = s.scalars(sa.select(Todo).order_by(Todo.id)).all() + + todo1, todo2, *_ = todos + assert todo1 < todo2 + + +class TestReprMixin(base.MixinTestBase): + @pytest.fixture(scope="class") + def Post(self, db: SQLAlchemy, User: t.Type[t.Any]) -> t.Generator[t.Type[t.Any], None, None]: + class Post(ReprMixin, db.Base): + id: Mapped[int] = sa.orm.mapped_column(primary_key=True, autoincrement=True) + title: Mapped[str] = sa.orm.mapped_column(default="default") + user_id: Mapped[t.Optional[int]] = sa.orm.mapped_column(sa.ForeignKey("user.id")) + + user: Mapped[t.Optional[User]] = sa.orm.relationship(backref="posts") + + 
db.create_all() + yield Post + + def test_mixin_generates_repr(self, db: QuartSQLAlchemy, Post: t.Any): + with db.bind.Session() as s: + with s.begin(): + post = Post() + s.add(post) + s.flush() + s.refresh(post) + + assert repr(post) == f"<{type(post).__name__} {post.id}>" + + +class TestSimpleDictMixin(base.MixinTestBase): + extra_mixins = (SimpleDictMixin,) + + @pytest.fixture(scope="class") + def Post(self, db: SQLAlchemy, User: t.Type[t.Any]) -> t.Generator[t.Type[t.Any], None, None]: + class Post(db.Base): + id: Mapped[int] = sa.orm.mapped_column(primary_key=True, autoincrement=True) + title: Mapped[str] = sa.orm.mapped_column(default="default") + user_id: Mapped[t.Optional[int]] = sa.orm.mapped_column(sa.ForeignKey("user.id")) + + user: Mapped[t.Optional[User]] = sa.orm.relationship(backref="posts") + + db.create_all() + yield Post + + def test_mixin_converts_model_to_dict(self, db: QuartSQLAlchemy, Post: t.Any, User: t.Any): + with db.bind.Session() as s: + with s.begin(): + user = s.scalars(sa.select(User)).first() + post = Post(user=user) + s.add(post) + s.flush() + s.refresh(post.user) + + with db.bind.Session() as s: + with s.begin(): + user = s.scalars(sa.select(User).options(sa.orm.selectinload(User.posts))).first() + + data = user.to_dict() + + for field in data: + assert data[field] == getattr(user, field) + + +class TestRecursiveMixin(base.MixinTestBase): + extra_mixins = (RecursiveDictMixin,) + + @pytest.fixture(scope="class") + def Post(self, db: SQLAlchemy, User: t.Type[t.Any]) -> t.Generator[t.Type[t.Any], None, None]: + class Post(db.Base): + id: Mapped[int] = sa.orm.mapped_column(primary_key=True, autoincrement=True) + title: Mapped[str] = sa.orm.mapped_column(default="default") + user_id: Mapped[t.Optional[int]] = sa.orm.mapped_column(sa.ForeignKey("user.id")) + + user: Mapped[t.Optional[User]] = sa.orm.relationship(backref="posts") + + db.create_all() + yield Post + + def test_mixin_converts_model_to_dict(self, db: QuartSQLAlchemy, Post: 
t.Any, User: t.Any): + with db.bind.Session() as s: + with s.begin(): + user = s.scalars(sa.select(User)).first() + post = Post(user=user) + s.add(post) + s.flush() + s.refresh(post.user) + + with db.bind.Session() as s: + with s.begin(): + user = s.scalars(sa.select(User).options(sa.orm.selectinload(User.posts))).first() + + data = user.to_dict() + + for col in sa.inspect(user).mapper.columns: + assert data[col.name] == getattr(user, col.name) diff --git a/tests/integration/model/model_test.py b/tests/integration/model/model_test.py new file mode 100644 index 0000000..651d908 --- /dev/null +++ b/tests/integration/model/model_test.py @@ -0,0 +1,59 @@ +from datetime import datetime + +import pytest +import sqlalchemy +import sqlalchemy.event +import sqlalchemy.exc +import sqlalchemy.ext +import sqlalchemy.ext.asyncio +import sqlalchemy.orm +import sqlalchemy.util +from sqlalchemy.orm import Mapped + +from quart_sqlalchemy import Base +from quart_sqlalchemy import Bind +from quart_sqlalchemy import SQLAlchemy +from quart_sqlalchemy import SQLAlchemyConfig +from quart_sqlalchemy.model.model import BaseMixins + + +sa = sqlalchemy + + +class TestSQLAlchemyWithCustomModelClass: + def test_base_class_with_declarative_preserves_class_and_table_metadata(self): + """This is nice to have as it decouples quart and quart_sqlalchemy from the data + models themselves. 
+ """ + + class User(Base): + id: Mapped[int] = sa.orm.mapped_column(primary_key=True, autoincrement=True) + + db = SQLAlchemy(SQLAlchemyConfig(base_class=Base)) + + db.create_all() + + with db.bind.Session() as s: + with s.begin(): + user = User() + s.add(user) + s.flush() + s.refresh(user) + + Base.registry.dispose() + Bind._instances.clear() + + def test_sqla_class_adds_declarative_base_when_missing_from_base_class(self): + db = SQLAlchemy(SQLAlchemyConfig(base_class=BaseMixins)) + + class User(db.Base): + id: Mapped[int] = sa.orm.mapped_column(primary_key=True, autoincrement=True) + + db.create_all() + + with db.bind.Session() as s: + with s.begin(): + user = User() + s.add(user) + s.flush() + s.refresh(user) diff --git a/tests/integration/retry_test.py b/tests/integration/retry_test.py index 4e4d60b..a4bb6a6 100644 --- a/tests/integration/retry_test.py +++ b/tests/integration/retry_test.py @@ -46,7 +46,7 @@ def test_retrying_session(self, db: SQLAlchemy, Todo: t.Type[t.Any], mocker): # s.commit() def test_retrying_session_class(self, db: SQLAlchemy, Todo: t.Type[t.Any], mocker): - class Unique(db.Model): + class Unique(db.Base): id: Mapped[int] = sa.orm.mapped_column( sa.Identity(), primary_key=True, autoincrement=True ) @@ -57,7 +57,7 @@ class Unique(db.Model): db.create_all() with retrying_session(db.bind) as s: - todo = Todo(title="hello") + todo = Unique(name="hello") s.add(todo) diff --git a/workspace.code-workspace b/workspace.code-workspace index 10a11a5..aa9e12e 100644 --- a/workspace.code-workspace +++ b/workspace.code-workspace @@ -13,6 +13,7 @@ "source.organizeImports": true } }, - "esbonio.sphinx.confDir": "" + "esbonio.sphinx.confDir": "", + "python.linting.pylintEnabled": false } }