diff --git a/examples/decorators/provide_session.py b/examples/decorators/provide_session.py index f67f2b0..469cfc4 100644 --- a/examples/decorators/provide_session.py +++ b/examples/decorators/provide_session.py @@ -46,11 +46,10 @@ def decorator(func: t.Callable[..., RT]) -> t.Callable[..., RT]: def wrapper(*args, **kwargs) -> RT: if "session" in kwargs or session_args_idx < len(args): return func(*args, **kwargs) - else: - bind = Bind.get_instance(bind_name) + bind = Bind.get_instance(bind_name) - with create_session(bind) as session: - return func(*args, session=session, **kwargs) + with create_session(bind) as session: + return func(*args, session=session, **kwargs) return wrapper diff --git a/examples/usrsrv/component/repository.py b/examples/usrsrv/component/repository.py index 44dffcd..2d1d64b 100644 --- a/examples/usrsrv/component/repository.py +++ b/examples/usrsrv/component/repository.py @@ -53,10 +53,12 @@ def __init__(self, session: Session): self._query = select(EntityMapper) def get(self, entity_id: EntityID) -> Entity: - dto = self._session.scalars(self._query.filter_by(uuid=entity_id)).one_or_none() - if not dto: + if dto := self._session.scalars( + self._query.filter_by(uuid=entity_id) + ).one_or_none(): + return Entity(dto) + else: raise NotFound(entity_id) - return Entity(dto) def save(self, entity: Entity) -> None: self._session.add(entity.dto) diff --git a/src/quart_sqlalchemy/config.py b/src/quart_sqlalchemy/config.py index 0190160..b29aefe 100644 --- a/src/quart_sqlalchemy/config.py +++ b/src/quart_sqlalchemy/config.py @@ -146,8 +146,9 @@ class EngineConfig(ConfigBase): @root_validator def scrub_execution_options(cls, values): if "execution_options" in values: - execute_options = values["execution_options"].dict(exclude_defaults=True) - if execute_options: + if execute_options := values["execution_options"].dict( + exclude_defaults=True + ): values["execution_options"] = execute_options return values diff --git a/src/quart_sqlalchemy/framework/extension.py b/src/quart_sqlalchemy/framework/extension.py index 70f0737..7c6dbad 100644 --- a/src/quart_sqlalchemy/framework/extension.py +++ b/src/quart_sqlalchemy/framework/extension.py @@ -14,7 +14,7 @@ def __init__( config: t.Optional[SQLAlchemyConfig] = None, app: t.Optional[Quart] = None, ): - initialize = False if config is None else True + initialize = config is not None super().__init__(config, initialize=initialize) if app is not None: diff --git a/src/quart_sqlalchemy/model/mixins.py b/src/quart_sqlalchemy/model/mixins.py index d08f446..c08eb58 100644 --- a/src/quart_sqlalchemy/model/mixins.py +++ b/src/quart_sqlalchemy/model/mixins.py @@ -301,7 +301,7 @@ def accumulate_mappings(class_, attribute) -> t.Dict[str, t.Any]: if base_class is class_: continue args = getattr(base_class, attribute, {}) - accumulated.update(args) + accumulated |= args return accumulated @@ -316,7 +316,7 @@ def accumulate_tuples_with_mapping(class_, attribute) -> t.Sequence[t.Any]: args = getattr(base_class, attribute, ()) for arg in args: if isinstance(arg, t.Mapping): - accumulated_map.update(arg) + accumulated_map |= arg else: accumulated_args.append(arg) diff --git a/src/quart_sqlalchemy/session.py b/src/quart_sqlalchemy/session.py index 36db6c3..40bd61e 100644 --- a/src/quart_sqlalchemy/session.py +++ b/src/quart_sqlalchemy/session.py @@ -45,17 +45,17 @@ def provide_global_contextual_session(func): @wraps(func) def wrapper(self, *args, **kwargs): session_in_args = any( - [isinstance(arg, (sa.orm.Session, sa.ext.asyncio.AsyncSession)) for arg in args] + isinstance(arg, (sa.orm.Session, sa.ext.asyncio.AsyncSession)) + for arg in args ) session_in_kwargs = "session" in kwargs session_provided = session_in_args or session_in_kwargs if session_provided: return func(self, *args, **kwargs) - else: - session = session_proxy() + session = session_proxy() - return func(self, session, *args, **kwargs) + return func(self, session, *args, **kwargs) return wrapper diff --git a/src/quart_sqlalchemy/sim/auth.py b/src/quart_sqlalchemy/sim/auth.py index dece2ba..95f1ed1 100644 --- a/src/quart_sqlalchemy/sim/auth.py +++ b/src/quart_sqlalchemy/sim/auth.py @@ -228,7 +228,7 @@ def auth_endpoint_security(self): results = self.authenticator.enforce(security_schemes, session) authorized_credentials = {} for result in results: - authorized_credentials.update(result) + authorized_credentials |= result g.authorized_credentials = authorized_credentials diff --git a/src/quart_sqlalchemy/sim/config.py b/src/quart_sqlalchemy/sim/config.py index d743ee7..2513480 100644 --- a/src/quart_sqlalchemy/sim/config.py +++ b/src/quart_sqlalchemy/sim/config.py @@ -16,21 +16,22 @@ sa = sqlalchemy + + class AppSettings(BaseSettings): + class Config: env_file = ".env", ".secrets.env" LOAD_BLUEPRINTS: t.List[str] = Field( - default_factory=lambda: list(("quart_sqlalchemy.sim.views.api",)) + default_factory=lambda: ["quart_sqlalchemy.sim.views.api"] ) LOAD_EXTENSIONS: t.List[str] = Field( - default_factory=lambda: list( - ( - "quart_sqlalchemy.sim.db.db", - "quart_sqlalchemy.sim.app.schema", - "quart_sqlalchemy.sim.auth.auth", - ) - ) + default_factory=lambda: [ + "quart_sqlalchemy.sim.db.db", + "quart_sqlalchemy.sim.app.schema", + "quart_sqlalchemy.sim.auth.auth", + ] ) SECURITY_SCHEMES: t.Dict[str, SecuritySchemeBase] = Field( default_factory=lambda: { @@ -52,4 +53,5 @@ class Config: WEB3_HTTPS_PROVIDER_URI: str = Field(env="WEB3_HTTPS_PROVIDER_URI") + settings = AppSettings() diff --git a/src/quart_sqlalchemy/sim/db.py b/src/quart_sqlalchemy/sim/db.py index 9634815..0f75ecb 100644 --- a/src/quart_sqlalchemy/sim/db.py +++ b/src/quart_sqlalchemy/sim/db.py @@ -51,19 +51,13 @@ def process_bind_param(self, value, dialect): """Data going into to the database will be transformed by this method. See ``ObjectID`` for the design and rational for this. """ - if value is None: - return None - - return ObjectID(value).decode() + return None if value is None else ObjectID(value).decode() def process_result_value(self, value, dialect): """Data going out from the database will be explicitly casted to the ``ObjectID``. """ - if value is None: - return None - - return ObjectID(value) + return None if value is None else ObjectID(value) class MyBase(BaseMixins, sa.orm.DeclarativeBase): diff --git a/src/quart_sqlalchemy/sim/handle.py b/src/quart_sqlalchemy/sim/handle.py index bd3bc12..1457a58 100644 --- a/src/quart_sqlalchemy/sim/handle.py +++ b/src/quart_sqlalchemy/sim/handle.py @@ -99,16 +99,11 @@ def update_app_name_by_id(self, session: sa.orm.Session, magic_client_id, app_na """ client = self.logic.MagicClient.update_by_id(session, magic_client_id, app_name=app_name) - if not client: - return None - - return client.app_name + return client.app_name if client else None @provide_global_contextual_session def update_by_id(self, session: sa.orm.Session, magic_client_id, **kwargs): - client = self.logic.MagicClient.update_by_id(session, magic_client_id, **kwargs) - - return client + return self.logic.MagicClient.update_by_id(session, magic_client_id, **kwargs) @provide_global_contextual_session def set_inactive_by_id(self, session: sa.orm.Session, magic_client_id): @@ -341,11 +336,10 @@ def sync_auth_wallet( wallet_type: t.Optional[WalletType] = None, ): with session.begin_nested(): - existing_wallet = self.logic.AuthWallet.get_by_auth_user_id( + if existing_wallet := self.logic.AuthWallet.get_by_auth_user_id( session, auth_user_id, - ) - if existing_wallet: + ): raise RuntimeError("WalletExistsForNetworkAndWalletType") return self.logic.AuthWallet.add( diff --git a/src/quart_sqlalchemy/sim/logic.py b/src/quart_sqlalchemy/sim/logic.py index 626adbf..3402edb 100644 --- a/src/quart_sqlalchemy/sim/logic.py +++ b/src/quart_sqlalchemy/sim/logic.py @@ -31,13 +31,12 @@ class LogicMeta(type): def __init__(cls, name, bases, cls_dict): if not hasattr(cls, "_registry"): cls._registry = {} - else: - if cls.__name__ not in cls._ignore: - model = getattr(cls, "model", None) - if model is not None: - name = model.__name__ + elif cls.__name__ not in cls._ignore: + model = getattr(cls, "model", None) + if model is not None: + name = model.__name__ - cls._registry[name] = cls() + cls._registry[name] = cls() super().__init__(name, bases, cls_dict) @@ -161,10 +160,7 @@ def add_by_email_and_client_id( user_type=user_type, ): logger.exception( - "User duplication for email: {} (client_id: {})".format( - email, - client_id, - ), + f"User duplication for email: {email} (client_id: {client_id})" ) raise DuplicateAuthUser() @@ -176,10 +172,7 @@ def add_by_email_and_client_id( **kwargs, ) logger.info( - "New auth user (id: {}) created by email (client_id: {})".format( - row.id, - client_id, - ), + f"New auth user (id: {row.id}) created by email (client_id: {client_id})" ) return row @@ -203,7 +196,7 @@ def add_by_client_id( date_verified=datetime.utcnow() if is_verified else None, ) logger.info( - "New auth user (id: {}) created by (client_id: {})".format(row.id, client_id), + f"New auth user (id: {row.id}) created by (client_id: {client_id})" ) return row @@ -409,7 +402,7 @@ def add( management_type=None, auth_user_id=None, ): - new_row = self._repository.add( + return self._repository.add( session, auth_user_id=auth_user_id, public_address=public_address, @@ -419,8 +412,6 @@ def add( network=network, ) - return new_row - @provide_global_contextual_session def get_by_id(self, session, model_id, allow_inactive=False, join_list=None): return self._repository.get_by_id( @@ -441,10 +432,7 @@ def get_by_public_address(self, session, public_address, network=None, is_active row = self._repository.get_by(session, filters=filters, allow_inactive=not is_active) - if not row: - return None - - return one(row) + return one(row) if row else None @provide_global_contextual_session def get_by_auth_user_id( @@ -470,10 +458,7 @@ def get_by_auth_user_id( session, filters=filters, join_list=join_list, allow_inactive=not is_active ) - if not rows: - return [] - - return rows + return rows or [] @provide_global_contextual_session def update_by_id(self, session, model_id, **kwargs): diff --git a/src/quart_sqlalchemy/sim/repo_adapter.py b/src/quart_sqlalchemy/sim/repo_adapter.py index 8e0a43d..521bd3f 100644 --- a/src/quart_sqlalchemy/sim/repo_adapter.py +++ b/src/quart_sqlalchemy/sim/repo_adapter.py @@ -59,11 +59,7 @@ def get_by( join_list = join_list or () - if order_by_clause is not None: - order_by_clause = (order_by_clause,) - else: - order_by_clause = () - + order_by_clause = (order_by_clause, ) if order_by_clause is not None else () return self._adapted.select( session, conditions=filters, @@ -129,9 +125,7 @@ def count_by( else: selectables = [sa.label("count", sa.func.count(self.model.id))] - for group in group_by: - selectables.append(group.expression) - + selectables.extend(group.expression for group in group_by) result = self._adapted.select(session, selectables, conditions=filters, group_by=group_by) return result.all() @@ -178,15 +172,16 @@ def yield_by_chunk( ): filters = filters or () join_list = join_list or () - results = self._adapted.select( + yield from self._adapted.select( session, conditions=filters, - options=[sa.orm.selectinload(getattr(self.model, attr)) for attr in join_list], + options=[ + sa.orm.selectinload(getattr(self.model, attr)) + for attr in join_list + ], include_inactive=allow_inactive, yield_by_chunk=chunk_size, ) - for result in results: - yield result class PydanticScalarResult(sa.ScalarResult, t.Generic[ModelSchemaT]): @@ -243,9 +238,7 @@ def insert( create_data = create_schema.dict() result = super().insert(session, create_data) - if sqla_model: - return result - return self.model_schema.from_orm(result) + return result if sqla_model else self.model_schema.from_orm(result) def update( self, @@ -266,9 +259,7 @@ def update( session.flush() session.refresh(existing) - if sqla_model: - return existing - return self.model_schema.from_orm(existing) + return existing if sqla_model else self.model_schema.from_orm(existing) def get( self, @@ -291,9 +282,7 @@ def get( if row is None: return - if sqla_model: - return row - return self.model_schema.from_orm(row) + return row if sqla_model else self.model_schema.from_orm(row) def select( self, diff --git a/src/quart_sqlalchemy/sim/schema.py b/src/quart_sqlalchemy/sim/schema.py index e679d61..19eb161 100644 --- a/src/quart_sqlalchemy/sim/schema.py +++ b/src/quart_sqlalchemy/sim/schema.py @@ -55,10 +55,7 @@ class Config: @validator("status") def set_status_by_error_code(cls, v, values): - error_code = values.get("error_code") - if error_code: - return "failed" - return "ok" + return "failed" if (error_code := values.get("error_code")) else "ok" class MagicClientSchema(BaseSchema): diff --git a/src/quart_sqlalchemy/sim/util.py b/src/quart_sqlalchemy/sim/util.py index d4d33d1..0a8582f 100644 --- a/src/quart_sqlalchemy/sim/util.py +++ b/src/quart_sqlalchemy/sim/util.py @@ -64,7 +64,7 @@ def __lt__(self, other): return False def __hash__(self): - return hash(tuple([self._encoded_id, self._decoded_id])) + return hash((self._encoded_id, self._decoded_id)) def __str__(self): return "{encoded_id}".format(encoded_id=self._encoded_id) @@ -88,10 +88,7 @@ def encode(self): return self._encoded_id def _decode(self, value): - if isinstance(value, int): - return value - else: - return self.hashids.decode(value)[0] + return value if isinstance(value, int) else self.hashids.decode(value)[0] def decode(self): return self._decoded_id diff --git a/tests/base.py b/tests/base.py index 8afa3fb..d41d85b 100644 --- a/tests/base.py +++ b/tests/base.py @@ -44,9 +44,7 @@ def app(self, app_config, request): @pytest.fixture(scope="class") def db(self, app: Quart) -> t.Generator[QuartSQLAlchemy, None, None]: - db = QuartSQLAlchemy(app=app) - - yield db + yield QuartSQLAlchemy(app=app) @pytest.fixture(scope="class") def models( @@ -163,9 +161,7 @@ def app(self, app_config, request): @pytest.fixture(scope="class") def db(self, app: Quart) -> t.Generator[QuartSQLAlchemy, None, None]: - db = QuartSQLAlchemy(app=app) - - yield db + yield QuartSQLAlchemy(app=app) # It's very important to clear the class _instances dict before recreating binds with the same name. # Bind._instances.clear()