From 7f17307e7deca08e7115e2995e24ac1e389bb7df Mon Sep 17 00:00:00 2001 From: 8thgencore Date: Sat, 9 Sep 2023 17:54:46 +0300 Subject: [PATCH 1/2] feat: migarte from passlib to bcrypt --- backend/app/app/api/deps.py | 41 ++++++++++------ backend/app/app/api/v1/endpoints/login.py | 45 ++++++++++------- backend/app/app/core/security.py | 49 ++++++++++++++----- backend/app/app/main.py | 59 ++++++++++++++--------- backend/app/pyproject.toml | 4 +- 5 files changed, 128 insertions(+), 70 deletions(-) diff --git a/backend/app/app/api/deps.py b/backend/app/app/api/deps.py index 5278f6e..06fd935 100644 --- a/backend/app/app/api/deps.py +++ b/backend/app/app/api/deps.py @@ -1,21 +1,21 @@ from collections.abc import AsyncGenerator from typing import Callable + +import redis.asyncio as aioredis from fastapi import Depends, HTTPException, status -from app.utils.token import get_valid_tokens -from app.utils.minio_client import MinioClient from fastapi.security import OAuth2PasswordBearer -from jose import jwt -from app.models.user_model import User -from pydantic import ValidationError +from jwt import DecodeError, ExpiredSignatureError, MissingRequiredClaimError +from redis.asyncio import Redis +from sqlmodel.ext.asyncio.session import AsyncSession + from app import crud -from app.core import security from app.core.config import settings +from app.core.security import decode_token from app.db.session import SessionLocal, SessionLocalCelery -from sqlmodel.ext.asyncio.session import AsyncSession +from app.models.user_model import User from app.schemas.common_schema import IMetaGeneral, TokenType -import redis.asyncio as aioredis -from redis.asyncio import Redis - +from app.utils.minio_client import MinioClient +from app.utils.token import get_valid_tokens reusable_oauth2 = OAuth2PasswordBearer( tokenUrl=f"{settings.API_V1_STR}/login/access-token" @@ -49,23 +49,32 @@ async def get_general_meta() -> IMetaGeneral: def get_current_user(required_roles: list[str] = None) -> Callable[[], User]: async def current_user( - token: str = Depends(reusable_oauth2), + access_token: str = Depends(reusable_oauth2), redis_client: Redis = Depends(get_redis_client), ) -> User: try: - payload = jwt.decode( - token, settings.SECRET_KEY, algorithms=[security.ALGORITHM] + payload = decode_token(access_token) + except ExpiredSignatureError: + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, + detail="Your token has expired. Please log in again.", ) - except (jwt.JWTError, ValidationError): + except DecodeError: raise HTTPException( status_code=status.HTTP_403_FORBIDDEN, - detail="Could not validate credentials", + detail="Error when decoding the token. Please check your request.", ) + except MissingRequiredClaimError: + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, + detail="There is no required field in your token. Please contact the administrator.", + ) + user_id = payload["sub"] valid_access_tokens = await get_valid_tokens( redis_client, user_id, TokenType.ACCESS ) - if valid_access_tokens and token not in valid_access_tokens: + if valid_access_tokens and access_token not in valid_access_tokens: raise HTTPException( status_code=status.HTTP_403_FORBIDDEN, detail="Could not validate credentials", diff --git a/backend/app/app/api/v1/endpoints/login.py b/backend/app/app/api/v1/endpoints/login.py index a96b027..6325df8 100644 --- a/backend/app/app/api/v1/endpoints/login.py +++ b/backend/app/app/api/v1/endpoints/login.py @@ -1,24 +1,22 @@ from datetime import timedelta -from fastapi import APIRouter, Body, Depends, HTTPException -from redis.asyncio import Redis -from app.utils.token import get_valid_tokens -from app.utils.token import delete_tokens -from app.utils.token import add_token_to_redis -from app.core.security import get_password_hash -from app.core.security import verify_password -from app.models.user_model import User -from app.api.deps import get_redis_client + +from fastapi import APIRouter, Body, Depends, HTTPException, status from fastapi.security import OAuth2PasswordRequestForm -from jose import jwt +from jwt import DecodeError, ExpiredSignatureError, MissingRequiredClaimError from pydantic import EmailStr -from pydantic import ValidationError +from redis.asyncio import Redis + from app import crud from app.api import deps +from app.api.deps import get_redis_client from app.core import security from app.core.config import settings -from app.schemas.common_schema import TokenType, IMetaGeneral -from app.schemas.token_schema import TokenRead, Token, RefreshToken +from app.core.security import decode_token, get_password_hash, verify_password +from app.models.user_model import User +from app.schemas.common_schema import IMetaGeneral, TokenType from app.schemas.response_schema import IPostResponseBase, create_response +from app.schemas.token_schema import RefreshToken, Token, TokenRead +from app.utils.token import add_token_to_redis, delete_tokens, get_valid_tokens router = APIRouter() @@ -147,11 +145,22 @@ async def get_new_access_token( Gets a new access token using the refresh token for future requests """ try: - payload = jwt.decode( - body.refresh_token, settings.SECRET_KEY, algorithms=[security.ALGORITHM] + payload = decode_token(body.refresh_token) + except ExpiredSignatureError: + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, + detail="Your token has expired. Please log in again.", + ) + except DecodeError: + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, + detail="Error when decoding the token. Please check your request.", + ) + except MissingRequiredClaimError as e: + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, + detail="There is no required field in your token. Please contact the administrator.", ) - except (jwt.JWTError, ValidationError): - raise HTTPException(status_code=403, detail="Refresh token invalid") if payload["type"] == "refresh": user_id = payload["sub"] @@ -163,7 +172,7 @@ async def get_new_access_token( access_token_expires = timedelta(minutes=settings.ACCESS_TOKEN_EXPIRE_MINUTES) user = await crud.user.get(id=user_id) - if getattr(user, "is_active"): + if user.is_active: access_token = security.create_access_token( payload["sub"], expires_delta=access_token_expires ) diff --git a/backend/app/app/core/security.py b/backend/app/app/core/security.py index 53f1703..bd1de9b 100644 --- a/backend/app/app/core/security.py +++ b/backend/app/app/core/security.py @@ -1,14 +1,15 @@ from datetime import datetime, timedelta from typing import Any + +import bcrypt +import jwt from cryptography.fernet import Fernet -from jose import jwt -from passlib.context import CryptContext + from app.core.config import settings -pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto") fernet = Fernet(str.encode(settings.ENCRYPT_KEY)) -ALGORITHM = "HS256" +JWT_ALGORITHM = "HS256" def create_access_token(subject: str | Any, expires_delta: timedelta = None) -> str: @@ -19,8 +20,12 @@ def create_access_token(subject: str | Any, expires_delta: timedelta = None) -> minutes=settings.ACCESS_TOKEN_EXPIRE_MINUTES ) to_encode = {"exp": expire, "sub": str(subject), "type": "access"} - encoded_jwt = jwt.encode(to_encode, settings.SECRET_KEY, algorithm=ALGORITHM) - return encoded_jwt + + return jwt.encode( + payload=to_encode, + key=settings.ENCRYPT_KEY, + algorithm=JWT_ALGORITHM, + ) def create_refresh_token(subject: str | Any, expires_delta: timedelta = None) -> str: @@ -31,16 +36,36 @@ def create_refresh_token(subject: str | Any, expires_delta: timedelta = None) -> minutes=settings.REFRESH_TOKEN_EXPIRE_MINUTES ) to_encode = {"exp": expire, "sub": str(subject), "type": "refresh"} - encoded_jwt = jwt.encode(to_encode, settings.SECRET_KEY, algorithm=ALGORITHM) - return encoded_jwt + return jwt.encode( + payload=to_encode, + key=settings.ENCRYPT_KEY, + algorithm=JWT_ALGORITHM, + ) + + +def decode_token(token: str) -> dict[str, Any]: + return jwt.decode( + jwt=token, + key=settings.ENCRYPT_KEY, + algorithms=[JWT_ALGORITHM], + ) + + +def verify_password(plain_password: str | bytes, hashed_password: str | bytes) -> bool: + if isinstance(plain_password, str): + plain_password = plain_password.encode() + if isinstance(hashed_password, str): + hashed_password = hashed_password.encode() + + return bcrypt.checkpw(plain_password, hashed_password) -def verify_password(plain_password: str, hashed_password: str) -> bool: - return pwd_context.verify(plain_password, hashed_password) +def get_password_hash(plain_password: str | bytes) -> str: + if isinstance(plain_password, str): + plain_password = plain_password.encode() -def get_password_hash(password: str) -> str: - return pwd_context.hash(password) + return bcrypt.hashpw(plain_password, bcrypt.gensalt()).decode() def get_data_encrypt(data) -> str: diff --git a/backend/app/app/main.py b/backend/app/app/main.py index 5959c16..71ac890 100644 --- a/backend/app/app/main.py +++ b/backend/app/app/main.py @@ -1,10 +1,9 @@ import gc import logging +from contextlib import asynccontextmanager from typing import Any from uuid import UUID, uuid4 -from app import crud -from app.schemas.common_schema import IChatResponse, IUserMessage -from app.utils.uuid6 import uuid7 + from fastapi import ( FastAPI, HTTPException, @@ -13,25 +12,27 @@ WebSocketDisconnect, status, ) -from app.core import security -from app.api.deps import get_redis_client -from fastapi_pagination import add_pagination -from pydantic import ValidationError -from starlette.middleware.cors import CORSMiddleware -from app.api.v1.api import api_router as api_router_v1 -from app.core.config import ModeEnum, settings +from fastapi_async_sqlalchemy import SQLAlchemyMiddleware, db from fastapi_cache import FastAPICache from fastapi_cache.backends.redis import RedisBackend -from fastapi_async_sqlalchemy import SQLAlchemyMiddleware, db -from contextlib import asynccontextmanager -from app.utils.fastapi_globals import g, GlobalsMiddleware -from transformers import pipeline from fastapi_limiter import FastAPILimiter -from jose import jwt from fastapi_limiter.depends import WebSocketRateLimiter +from fastapi_pagination import add_pagination +from jwt import DecodeError, ExpiredSignatureError, MissingRequiredClaimError from langchain.chat_models import ChatOpenAI from langchain.schema import HumanMessage from sqlalchemy.pool import NullPool, QueuePool +from starlette.middleware.cors import CORSMiddleware +from transformers import pipeline + +from app import crud +from app.api.deps import get_redis_client +from app.api.v1.api import api_router as api_router_v1 +from app.core.config import ModeEnum, settings +from app.core.security import decode_token +from app.schemas.common_schema import IChatResponse, IUserMessage +from app.utils.fastapi_globals import GlobalsMiddleware, g +from app.utils.uuid6 import uuid7 async def user_id_identifier(request: Request): @@ -45,16 +46,25 @@ async def user_id_identifier(request: Request): if len(header_parts) == 2 and header_parts[0].lower() == "bearer": token = header_parts[1] try: - payload = jwt.decode( - token, settings.SECRET_KEY, algorithms=[security.ALGORITHM] + payload = decode_token(token) + except ExpiredSignatureError: + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, + detail="Your token has expired. Please log in again.", ) - except (jwt.JWTError, ValidationError): + except DecodeError: raise HTTPException( status_code=status.HTTP_403_FORBIDDEN, - detail="Could not validate credentials", + detail="Error when decoding the token. Please check your request.", ) + except MissingRequiredClaimError: + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, + detail="There is no required field in your token. Please contact the administrator.", + ) + user_id = payload["sub"] - print("here2", user_id) + return user_id if request.scope["type"] == "websocket": @@ -65,7 +75,7 @@ async def user_id_identifier(request: Request): return forwarded.split(",")[0] client = request.client - ip = getattr(client, "host", "0.0.0.0") + ip = getattr(client, "host", "0.0.0.0") return ip + ":" + request.scope["path"] @@ -134,7 +144,12 @@ class CustomException(Exception): code: str message: str - def __init__(self, http_code: int = 500, code: str | None = None, message: str = 'This is an error message'): + def __init__( + self, + http_code: int = 500, + code: str | None = None, + message: str = "This is an error message", + ): self.http_code = http_code self.code = code if code else str(self.http_code) self.message = message diff --git a/backend/app/pyproject.toml b/backend/app/pyproject.toml index 82ca7d2..8743f7f 100644 --- a/backend/app/pyproject.toml +++ b/backend/app/pyproject.toml @@ -44,9 +44,9 @@ alembic = "^1.10.2" asyncpg = "^0.27.0" fastapi = {extras = ["all"], version = "^0.95.2"} sqlmodel = "^0.0.8" -python-jose = "^3.3.0" cryptography = "^38.0.3" -passlib = "^1.7.4" +bcrypt = "^4.0.1" +pyjwt = { extras = ["crypto"], version = "^2.8.0" } SQLAlchemy-Utils = "^0.38.3" SQLAlchemy = "^1.4.40" fastapi-pagination = {extras = ["sqlalchemy"], version = "^0.11.4"} From 8ab49fcd7e2dc13895d9fc7cc5bd468e220076d2 Mon Sep 17 00:00:00 2001 From: 8thgencore Date: Sat, 9 Sep 2023 18:01:11 +0300 Subject: [PATCH 2/2] bump: cryptography version --- backend/app/pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/backend/app/pyproject.toml b/backend/app/pyproject.toml index 8743f7f..c1a488c 100644 --- a/backend/app/pyproject.toml +++ b/backend/app/pyproject.toml @@ -44,7 +44,7 @@ alembic = "^1.10.2" asyncpg = "^0.27.0" fastapi = {extras = ["all"], version = "^0.95.2"} sqlmodel = "^0.0.8" -cryptography = "^38.0.3" +cryptography = "^41.0.3" bcrypt = "^4.0.1" pyjwt = { extras = ["crypto"], version = "^2.8.0" } SQLAlchemy-Utils = "^0.38.3"