From d07c56422a51d3309c0270a07bc08030a9017384 Mon Sep 17 00:00:00 2001 From: Jacky Su Date: Tue, 28 Mar 2023 22:49:06 +0100 Subject: [PATCH] add comments, add expires_at to token --- api.py | 33 +++++++++++++++++---------------- app.py | 4 ++++ database/database.env | 1 + db.py | 8 ++++++-- models.py | 1 + requirements.txt | Bin 318 -> 314 bytes 6 files changed, 29 insertions(+), 18 deletions(-) create mode 100644 database/database.env diff --git a/api.py b/api.py index a9cfc79..fe49f89 100644 --- a/api.py +++ b/api.py @@ -36,11 +36,12 @@ DEFAULT_PAGE_SIZE = 30 MAX_PAGE_SIZE = 50 +# jinja2 template for server side render html # templates = Jinja2Templates(directory="templates") -# 这里是 tokenUrl,而不是 token_url,是为了和 OAuth2 规范统一 -# tokenUrl 是为了指定 OpenAPI 前端登录时的接口地址 -# OAuthPasswordBearer 把 Authorization Header 的 Bearer 取出来,然后传给 tokenUrl +# it is tokenUrl instead of kebab-case token_url to unify with OAuth2 scheme +# tokenUrl is to specify OpenAPI route address of token endpoint for frontend login +# OAuthPasswordBearer extracts Bearer from Authorization Header and send it to tokenUrl oauth2_bearer = OAuth2PasswordBearer(tokenUrl=API_PREFIX + "/user/token") pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto") @@ -51,6 +52,7 @@ async def startup(): await init_db() +# for further refactor the functions down below should be moved to a utils package def is_admin(user: User) -> bool: return True if user.role else False @@ -86,6 +88,7 @@ def _decode_token(token: str) -> dict: raise unauthorized_error("Could not validate credentials") +# warning: this function returns all the users without limit and offset async def get_all_users_db(session: AsyncSession = Depends( get_session)) -> List[User]: result = await session.execute(select(User)) @@ -108,6 +111,7 @@ async def get_user_with_email_db( # return [User(**user.__dict__) for user in users] +# get current user from token by decoding it async def get_current_user(token: str = Depends(oauth2_bearer), session: AsyncSession = Depends( get_session)) -> User: @@ -135,12 +139,7 @@ async def get_current_user(token: str = Depends(oauth2_bearer), return user -@router.on_event("startup") -async def on_startup(): - await init_db() - - -@router.post("/user/signup") +@router.post("/user/signup", response_model=Token) async def create_user(user: UserSignup, session: AsyncSession = Depends(get_session)): if not user.email or not re.match( @@ -176,7 +175,7 @@ async def create_user(user: UserSignup, "role": 0 }, expires=timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES)) - response = {"access_token": token, "token_type": "bearer", "role": 0} + response = {"access_token": token, "token_type": "bearer", "role": 0, "expires_at": datetime.utcnow() + timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES)} return response @@ -287,7 +286,8 @@ async def login(form: OAuth2PasswordRequestForm = Depends(), response = { "access_token": token, "token_type": "bearer", - "role": user.role + "role": user.role, + "expires_at": datetime.utcnow() + timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES) } return response @@ -330,11 +330,12 @@ async def admin_get_all_users(response: Response, @router.get("/admin/projects", response_model=List[ProjectWithUserAndTags]) -async def admin_get_all_projects(response: Response, - per_page: int = DEFAULT_PAGE_SIZE, - page: int = DEFAULT_PAGE, - session: AsyncSession = Depends(get_session), - current_user: User = Depends(get_current_user)): +async def admin_get_all_projects( + response: Response, + per_page: int = DEFAULT_PAGE_SIZE, + page: int = DEFAULT_PAGE, + session: AsyncSession = Depends(get_session), + current_user: User = Depends(get_current_user)): if not current_user: raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Unauthorized") diff --git a/app.py b/app.py index abaea2f..7f03a2e 100644 --- a/app.py +++ b/app.py @@ -3,6 +3,8 @@ from fastapi.middleware.cors import CORSMiddleware import api + +# Create FastAPI app and add CORS middleware app = FastAPI() app.add_middleware( CORSMiddleware, @@ -12,9 +14,11 @@ allow_headers=["*"], ) +# Add API routes app.include_router(api.router) +# redirect to docs @app.get("/") async def index(): return RedirectResponse(url="/docs", status_code=302) diff --git a/database/database.env b/database/database.env new file mode 100644 index 0000000..97c3212 --- /dev/null +++ b/database/database.env @@ -0,0 +1 @@ +DATABASE_URL=sqlite+aiosqlite:///./database/database.db \ No newline at end of file diff --git a/db.py b/db.py index 26c5070..8b609da 100644 --- a/db.py +++ b/db.py @@ -5,18 +5,22 @@ from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine from sqlalchemy.orm import sessionmaker - -DATABASE_URL = 'sqlite+aiosqlite:///./database/database.db' +from dotenv import load_dotenv +# get the database url from the environment variable +load_dotenv("./database/database.env") +DATABASE_URL = os.getenv("DATABASE_URL") engine = create_async_engine(DATABASE_URL, echo=True, future=True) +# initialize db and create tables if they don't exist async def init_db(): async with engine.begin() as conn: # await conn.run_sync(SQLModel.metadata.drop_all) await conn.run_sync(SQLModel.metadata.create_all) +# create a new session and return it when called async def get_session() -> AsyncSession: async_session = sessionmaker( engine, class_=AsyncSession, expire_on_commit=False diff --git a/models.py b/models.py index 21c78d9..29c04ec 100644 --- a/models.py +++ b/models.py @@ -44,6 +44,7 @@ class Token(SQLModel): access_token: str token_type: str role: int + expires_at: datetime class CategoryBase(SQLModel): diff --git a/requirements.txt b/requirements.txt index 5f0301671dfc8d23e138b6ba0aafe9780c7eec5c..198583815abdfa6779f4ec87473cd58c01f5167b 100644 GIT binary patch delta 34 mcmdnTw2Ntj0;5y`LnT8ALk2@WLmq=JLkf^CVMqnC%NPKc)Cb1^ delta 38 pcmdnRw2x_n0;60SLn1>lLkW;9V8~?9Wyk@Nxj