From e543eb302fea06718e06737cd104e020d3494635 Mon Sep 17 00:00:00 2001 From: 2jun0 Date: Wed, 17 Jan 2024 16:10:40 +0900 Subject: [PATCH] Change deprecated sqlmodel method Use session.exec instead of session.execute --- fastapi_users_db_sqlmodel/__init__.py | 10 +++++----- fastapi_users_db_sqlmodel/access_token.py | 10 +++++----- 2 files changed, 10 insertions(+), 10 deletions(-) diff --git a/fastapi_users_db_sqlmodel/__init__.py b/fastapi_users_db_sqlmodel/__init__.py index 695c5e2..57195bf 100644 --- a/fastapi_users_db_sqlmodel/__init__.py +++ b/fastapi_users_db_sqlmodel/__init__.py @@ -5,9 +5,9 @@ from fastapi_users.db.base import BaseUserDatabase from fastapi_users.models import ID, OAP, UP from pydantic import UUID4, EmailStr -from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.orm import selectinload from sqlmodel import Field, Session, SQLModel, func, select +from sqlmodel.ext.asyncio.session import AsyncSession __version__ = "0.3.0" @@ -174,11 +174,11 @@ async def get_by_email(self, email: str) -> Optional[UP]: statement = select(self.user_model).where( # type: ignore func.lower(self.user_model.email) == func.lower(email) ) - results = await self.session.execute(statement) + results = await self.session.exec(statement) object = results.first() if object is None: return None - return object[0] + return object async def get_by_oauth_account(self, oauth: str, account_id: str) -> Optional[UP]: """Get a single user by OAuth account id.""" @@ -190,10 +190,10 @@ async def get_by_oauth_account(self, oauth: str, account_id: str) -> Optional[UP .where(self.oauth_account_model.account_id == account_id) .options(selectinload(self.oauth_account_model.user)) # type: ignore ) - results = await self.session.execute(statement) + results = await self.session.exec(statement) oauth_account = results.first() if oauth_account: - user = oauth_account[0].user # type: ignore + user = oauth_account.user # type: ignore return user return None diff --git a/fastapi_users_db_sqlmodel/access_token.py b/fastapi_users_db_sqlmodel/access_token.py index 8a4519e..ce1722c 100644 --- a/fastapi_users_db_sqlmodel/access_token.py +++ b/fastapi_users_db_sqlmodel/access_token.py @@ -4,8 +4,8 @@ from fastapi_users.authentication.strategy.db import AP, AccessTokenDatabase from pydantic import UUID4 from sqlalchemy import Column, types -from sqlalchemy.ext.asyncio import AsyncSession from sqlmodel import Field, Session, SQLModel, select +from sqlmodel.ext.asyncio.session import AsyncSession from fastapi_users_db_sqlmodel.generics import TIMESTAMPAware, now_utc @@ -49,11 +49,11 @@ async def get_by_token( if max_age is not None: statement = statement.where(self.access_token_model.created_at >= max_age) - results = self.session.execute(statement) + results = self.session.exec(statement) access_token = results.first() if access_token is None: return None - return access_token[0] + return access_token async def create(self, create_dict: Dict[str, Any]) -> AP: access_token = self.access_token_model(**create_dict) @@ -96,11 +96,11 @@ async def get_by_token( if max_age is not None: statement = statement.where(self.access_token_model.created_at >= max_age) - results = await self.session.execute(statement) + results = await self.session.exec(statement) access_token = results.first() if access_token is None: return None - return access_token[0] + return access_token async def create(self, create_dict: Dict[str, Any]) -> AP: access_token = self.access_token_model(**create_dict)