From eb61e24d1000badda74b05cecc96f88ffa2ecb1d Mon Sep 17 00:00:00 2001 From: Sourcery AI <> Date: Fri, 31 Mar 2023 23:15:01 +0000 Subject: [PATCH] 'Refactored by Sourcery' --- src/quart_sqlalchemy/model/mixins.py | 6 +- src/quart_sqlalchemy/sim/auth.py | 2 +- src/quart_sqlalchemy/sim/db.py | 10 +--- src/quart_sqlalchemy/sim/handle.py | 63 +++++---------------- src/quart_sqlalchemy/sim/logic.py | 71 +++++++++--------------- src/quart_sqlalchemy/sim/repo_adapter.py | 34 ++++-------- src/quart_sqlalchemy/sim/schema.py | 5 +- src/quart_sqlalchemy/sim/util.py | 7 +-- src/quart_sqlalchemy/sqla.py | 2 +- 9 files changed, 60 insertions(+), 140 deletions(-) diff --git a/src/quart_sqlalchemy/model/mixins.py b/src/quart_sqlalchemy/model/mixins.py index 3dc3834..ea53af3 100644 --- a/src/quart_sqlalchemy/model/mixins.py +++ b/src/quart_sqlalchemy/model/mixins.py @@ -63,7 +63,7 @@ def __lt__(self, other): if column.primary_key: continue - if not (getattr(self, key) == getattr(other, key)): + if getattr(self, key) != getattr(other, key): return False return True @@ -263,7 +263,7 @@ def accumulate_mappings(class_, attribute) -> t.Dict[str, t.Any]: if base_class is class_: continue args = getattr(base_class, attribute, {}) - accumulated.update(args) + accumulated |= args return accumulated @@ -278,7 +278,7 @@ def accumulate_tuples_with_mapping(class_, attribute) -> t.Sequence[t.Any]: args = getattr(base_class, attribute, ()) for arg in args: if isinstance(arg, t.Mapping): - accumulated_map.update(arg) + accumulated_map |= arg else: accumulated_args.append(arg) diff --git a/src/quart_sqlalchemy/sim/auth.py b/src/quart_sqlalchemy/sim/auth.py index 80a908a..7925fa2 100644 --- a/src/quart_sqlalchemy/sim/auth.py +++ b/src/quart_sqlalchemy/sim/auth.py @@ -249,7 +249,7 @@ def auth_endpoint_security(self): results = self.authenticator.enforce(security_schemes, session) authorized_credentials = {} for result in results: - authorized_credentials.update(result) + authorized_credentials |= result g.authorized_credentials = authorized_credentials diff --git a/src/quart_sqlalchemy/sim/db.py b/src/quart_sqlalchemy/sim/db.py index e7677f6..0b49e8f 100644 --- a/src/quart_sqlalchemy/sim/db.py +++ b/src/quart_sqlalchemy/sim/db.py @@ -32,19 +32,13 @@ def process_bind_param(self, value, dialect): """Data going into to the database will be transformed by this method. See ``ObjectID`` for the design and rational for this. """ - if value is None: - return None - - return ObjectID(value).decode() + return None if value is None else ObjectID(value).decode() def process_result_value(self, value, dialect): """Data going out from the database will be explicitly casted to the ``ObjectID``. """ - if value is None: - return None - - return ObjectID(value) + return None if value is None else ObjectID(value) class MyBase(Base): diff --git a/src/quart_sqlalchemy/sim/handle.py b/src/quart_sqlalchemy/sim/handle.py index e079da4..189fe42 100644 --- a/src/quart_sqlalchemy/sim/handle.py +++ b/src/quart_sqlalchemy/sim/handle.py @@ -145,15 +145,12 @@ def update_app_name_by_id(self, magic_client_id, app_name): self.session, magic_client_id, app_name=app_name ) - if not client: - return None - - return client.app_name + return client.app_name if client else None def update_by_id(self, magic_client_id, **kwargs): - client = self.logic.MagicClient.update_by_id(self.session, magic_client_id, **kwargs) - - return client + return self.logic.MagicClient.update_by_id( + self.session, magic_client_id, **kwargs + ) def set_inactive_by_id(self, magic_client_id): """ @@ -272,37 +269,17 @@ def get_or_create_by_email_and_client_id( client_id, user_type=EntityType.MAGIC.value, ): - auth_user = self.logic.AuthUser.get_by_email_and_client_id( + return self.logic.AuthUser.get_by_email_and_client_id( self.session, email, client_id, user_type=user_type, + ) or self.logic.AuthUser.add_by_email_and_client_id( + self.session, + client_id, + email=email, + user_type=user_type, ) - if not auth_user: - # try: - # email = enhanced_email_validation( - # email, - # source=MAGIC, - # # So we don't affect sign-up. - # silence_network_error=True, - # ) - # except ( - # EnhanceEmailValidationError, - # EnhanceEmailSuggestionError, - # ) as e: - # logger.warning( - # "Email Start Attempt.", - # exc_info=True, - # ) - # raise EnhancedEmailValidation(error_message=str(e)) from e - - auth_user = self.logic.AuthUser.add_by_email_and_client_id( - self.session, - client_id, - email=email, - user_type=user_type, - ) - return auth_user def get_by_id_and_validate_exists(self, auth_user_id): """This function helps formalize how a non-existent auth user should be handled.""" @@ -327,14 +304,12 @@ def create_verified_user( email, user_type=user_type, ).id - auth_user = self.logic.AuthUser._update_by_id( + return self.logic.AuthUser._update_by_id( self.session, auid, date_verified=datetime.utcnow(), ) - return auth_user - # def get_auth_user_from_public_address(self, public_address): # wallet = self.logic.AuthWallet.get_by_public_address(public_address) @@ -458,7 +433,7 @@ def search_by_client_id_and_substring( if not isinstance(substring, str) or len(substring) < 3: raise InvalidSubstringError() - auth_users = self.logic.AuthUser.get_by_client_id_with_substring_search( + return self.logic.AuthUser.get_by_client_id_with_substring_search( self.session, client_id, substring, @@ -467,15 +442,6 @@ def search_by_client_id_and_substring( # join_list=join_list, ) - # mfa_enablements = self.auth_user_mfa_handler.is_active_batch( - # [auth_user.id for auth_user in auth_users], - # ) - # for auth_user in auth_users: - # if mfa_enablements[auth_user.id] is False: - # auth_user.mfa_methods = [] - - return auth_users - def is_magic_connect_enabled(self, auth_user_id=None, auth_user=None): if auth_user is None and auth_user_id is None: raise Exception("At least one argument needed: auth_user_id or auth_user.") @@ -575,11 +541,10 @@ def sync_auth_wallet( encrypted_private_address, wallet_management_type, ): - existing_wallet = self.logic.AuthWallet.get_by_auth_user_id( + if existing_wallet := self.logic.AuthWallet.get_by_auth_user_id( self.session, auth_user_id, - ) - if existing_wallet: + ): raise RuntimeError("WalletExistsForNetworkAndWalletType") return self.logic.AuthWallet.add( diff --git a/src/quart_sqlalchemy/sim/logic.py b/src/quart_sqlalchemy/sim/logic.py index ceb6ffd..f04ad49 100644 --- a/src/quart_sqlalchemy/sim/logic.py +++ b/src/quart_sqlalchemy/sim/logic.py @@ -230,10 +230,7 @@ def _get_or_add_by_phone_number_and_client_id( provenance=Provenance.SMS, ) logger.info( - "New auth user (id: {}) created by phone number (client_id: {})".format( - row.id, - client_id, - ), + f"New auth user (id: {row.id}) created by phone number (client_id: {client_id})" ) return row @@ -261,10 +258,7 @@ def _add_by_email_and_client_id( user_type=user_type, ): logger.exception( - "User duplication for email: {} (client_id: {})".format( - email, - client_id, - ), + f"User duplication for email: {email} (client_id: {client_id})" ) raise DuplicateAuthUser() @@ -276,10 +270,7 @@ def _add_by_email_and_client_id( **kwargs, ) logger.info( - "New auth user (id: {}) created by email (client_id: {})".format( - row.id, - client_id, - ), + f"New auth user (id: {row.id}) created by email (client_id: {client_id})" ) return row @@ -305,10 +296,7 @@ def _add_by_client_id( date_verified=datetime.utcnow() if is_verified else None, ) logger.info( - "New auth user (id: {}) created by (client_id: {})".format( - row.id, - client_id, - ), + f"New auth user (id: {row.id}) created by (client_id: {client_id})" ) return row @@ -559,20 +547,21 @@ def _get_by_client_ids_and_user_type( offset=None, limit=None, ): - if not client_ids: - return [] - - return self._repository.get_by( - session, - filters=[ - auth_user_model.client_id.in_(client_ids), - auth_user_model.user_type == user_type, - # auth_user_model.is_active == True, # noqa: E712, - auth_user_model.date_verified != None, - ], - offset=offset, - limit=limit, - order_by_clause=auth_user_model.id.desc(), + return ( + self._repository.get_by( + session, + filters=[ + auth_user_model.client_id.in_(client_ids), + auth_user_model.user_type == user_type, + # auth_user_model.is_active == True, # noqa: E712, + auth_user_model.date_verified != None, + ], + offset=offset, + limit=limit, + order_by_clause=auth_user_model.id.desc(), + ) + if client_ids + else [] ) # get_by_client_ids_and_user_type = with_db_session(ro=True)( @@ -597,7 +586,7 @@ def _get_by_client_id_with_substring_search( or_( auth_user_model.provenance == Provenance.SMS, auth_user_model.provenance == Provenance.LINK, - auth_user_model.provenance == None, # noqa: E711 + auth_user_model.provenance is None, ), or_( auth_user_model.phone_number.contains(substring), @@ -691,7 +680,8 @@ def _get_by_email_for_interop( .options(contains_eager(auth_user_model.wallets)) .join( auth_user_model.magic_client.and_( - magic_client_model.connect_interop == ConnectInteropStatus.ENABLED, + magic_client_model.connect_interop + == ConnectInteropStatus.ENABLED, ), ) .options(contains_eager(auth_user_model.magic_client)) @@ -708,8 +698,7 @@ def _get_by_email_for_interop( .filter( auth_user_model.email == email, auth_user_model.user_type == EntityType.MAGIC.value, - # auth_user_model.is_active == 1, - auth_user_model.linked_primary_auth_user_id == None, # noqa: E711 + auth_user_model.linked_primary_auth_user_id is None, ) .populate_existing() ) @@ -764,7 +753,7 @@ def _add( management_type=None, auth_user_id=None, ): - new_row = self._repository.add( + return self._repository.add( session, auth_user_id=auth_user_id, public_address=public_address, @@ -774,8 +763,6 @@ def _add( network=network, ) - return new_row - # add = with_db_session(ro=False)(_add) add = _add @@ -813,10 +800,7 @@ def get_by_public_address(self, session, public_address, network=None, is_active row = self._repository.get_by(session, filters=filters, allow_inactive=not is_active) - if not row: - return None - - return one(row) + return one(row) if row else None # @with_db_session(ro=True) def get_by_auth_user_id( @@ -857,10 +841,7 @@ def get_by_auth_user_id( session, filters=filters, join_list=join_list, allow_inactive=not is_active ) - if not rows: - return [] - - return rows + return rows or [] def _update_by_id(self, session, model_id, **kwargs): self._repository.update(session, model_id, **kwargs) diff --git a/src/quart_sqlalchemy/sim/repo_adapter.py b/src/quart_sqlalchemy/sim/repo_adapter.py index 3d6b48e..7de2b0c 100644 --- a/src/quart_sqlalchemy/sim/repo_adapter.py +++ b/src/quart_sqlalchemy/sim/repo_adapter.py @@ -67,11 +67,7 @@ def get_by( join_list = join_list or () - if order_by_clause is not None: - order_by_clause = (order_by_clause,) - else: - order_by_clause = () - + order_by_clause = (order_by_clause, ) if order_by_clause is not None else () return self.repo.select( session, conditions=filters, @@ -132,9 +128,7 @@ def count_by( else: selectables = [label("count", func.count(self.model.id))] - for group in group_by: - selectables.append(group.expression) - + selectables.extend(group.expression for group in group_by) result = self.repo.select(session, selectables, conditions=filters, group_by=group_by) return result.all() @@ -181,15 +175,15 @@ def yield_by_chunk( ): filters = filters or () join_list = join_list or () - results = self.repo.select( + yield from self.repo.select( session, conditions=filters, - options=[selectinload(getattr(self.model, attr)) for attr in join_list], + options=[ + selectinload(getattr(self.model, attr)) for attr in join_list + ], include_inactive=allow_inactive, yield_by_chunk=chunk_size, ) - for result in results: - yield result class PydanticScalarResult(ScalarResult): @@ -245,9 +239,7 @@ def insert( create_data = create_schema.dict() result = super().insert(session, create_data) - if sqla_model: - return result - return self.schema.from_orm(result) + return result if sqla_model else self.schema.from_orm(result) def update( self, @@ -267,9 +259,7 @@ def update( session.add(existing) session.flush() session.refresh(existing) - if sqla_model: - return existing - return self.schema.from_orm(existing) + return existing if sqla_model else self.schema.from_orm(existing) def get( self, @@ -292,9 +282,7 @@ def get( if row is None: return - if sqla_model: - return row - return self.schema.from_orm(row) + return row if sqla_model else self.schema.from_orm(row) def select( self, @@ -328,6 +316,4 @@ def select( include_inactive, yield_by_chunk, ) - if sqla_model: - return result - return PydanticScalarResult(result, self.schema) + return result if sqla_model else PydanticScalarResult(result, self.schema) diff --git a/src/quart_sqlalchemy/sim/schema.py b/src/quart_sqlalchemy/sim/schema.py index e0304b7..b1f564a 100644 --- a/src/quart_sqlalchemy/sim/schema.py +++ b/src/quart_sqlalchemy/sim/schema.py @@ -37,7 +37,4 @@ class ResponseWrapper(BaseSchema): @validator("status") def set_status_by_error_code(cls, v, values): - error_code = values.get("error_code") - if error_code: - return "failed" - return "ok" + return "failed" if (error_code := values.get("error_code")) else "ok" diff --git a/src/quart_sqlalchemy/sim/util.py b/src/quart_sqlalchemy/sim/util.py index d4d33d1..0a8582f 100644 --- a/src/quart_sqlalchemy/sim/util.py +++ b/src/quart_sqlalchemy/sim/util.py @@ -64,7 +64,7 @@ def __lt__(self, other): return False def __hash__(self): - return hash(tuple([self._encoded_id, self._decoded_id])) + return hash((self._encoded_id, self._decoded_id)) def __str__(self): return "{encoded_id}".format(encoded_id=self._encoded_id) @@ -88,10 +88,7 @@ def encode(self): return self._encoded_id def _decode(self, value): - if isinstance(value, int): - return value - else: - return self.hashids.decode(value)[0] + return value if isinstance(value, int) else self.hashids.decode(value)[0] def decode(self): return self._decoded_id diff --git a/src/quart_sqlalchemy/sqla.py b/src/quart_sqlalchemy/sqla.py index 56b8b60..7909235 100644 --- a/src/quart_sqlalchemy/sqla.py +++ b/src/quart_sqlalchemy/sqla.py @@ -43,7 +43,7 @@ class Model(self.config.model_class, sa.orm.DeclarativeBase): if base_class is Model: continue base_map = getattr(base_class, "type_annotation_map", {}).copy() - type_annotation_map.update(base_map) + type_annotation_map |= base_map Model.registry.type_annotation_map.update(type_annotation_map) self.Model = Model