Skip to content

Commit

Permalink
Merge pull request #15 from hotosm/auth
Browse files Browse the repository at this point in the history
Implement JWT auth in Google oauth
  • Loading branch information
nrjadkry authored Jul 1, 2024
2 parents 526b287 + 3bb79e3 commit 7e0acd6
Show file tree
Hide file tree
Showing 6 changed files with 63 additions and 36 deletions.
1 change: 1 addition & 0 deletions src/backend/app/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,7 @@ def assemble_db_connection(cls, v: Optional[str], info: ValidationInfo) -> Any:
S3_BUCKET_NAME: str = "dtm-data"
S3_DOWNLOAD_ROOT: Optional[str] = None

ALGORITHM: str = "HS256"
ACCESS_TOKEN_EXPIRE_MINUTES: int = 60 * 24 * 1 # 1 day
REFRESH_TOKEN_EXPIRE_MINUTES: int = 60 * 24 * 8 # 8 day

Expand Down
23 changes: 14 additions & 9 deletions src/backend/app/projects/project_routes.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,11 +22,8 @@
)


@router.delete('/{project_id}', tags=["Projects"])
def delete_project_by_id(
project_id: int,
db: Session = Depends(database.get_db)
):
@router.delete("/{project_id}", tags=["Projects"])
def delete_project_by_id(project_id: int, db: Session = Depends(database.get_db)):
"""
Delete a project by its ID, along with all associated tasks.
Expand All @@ -41,20 +38,28 @@ def delete_project_by_id(
HTTPException: If the project is not found.
"""
# Query for the project
project = db.query(db_models.DbProject).filter(db_models.DbProject.id == project_id).first()
project = (
db.query(db_models.DbProject)
.filter(db_models.DbProject.id == project_id)
.first()
)
if not project:
raise HTTPException(status_code=404, detail="Project not found.")

# Query and delete associated tasks
tasks = db.query(db_models.DbTask).filter(db_models.DbTask.project_id == project_id).all()
tasks = (
db.query(db_models.DbTask)
.filter(db_models.DbTask.project_id == project_id)
.all()
)
for task in tasks:
db.delete(task)

# Delete the project
db.delete(project)
db.commit()
return {"message": f"Project ID: {project_id} is deleted successfully."}


@router.post(
"/create_project", tags=["Projects"], response_model=project_schemas.ProjectOut
Expand Down
8 changes: 7 additions & 1 deletion src/backend/app/users/oauth_routes.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,10 @@
from app.users.user_routes import router
from app.users.user_deps import init_google_auth, login_required
from app.users.user_schemas import AuthUser
from app.users import user_crud
from app.config import settings


if settings.DEBUG:
os.environ["OAUTHLIB_INSECURE_TRANSPORT"] = "1"

Expand Down Expand Up @@ -39,7 +41,11 @@ async def callback(request: Request, google_auth=Depends(init_google_auth)):

callback_url = str(request.url)
access_token = google_auth.callback(callback_url).get("access_token")
return access_token

user_data = google_auth.deserialize_access_token(access_token)
access_token, refresh_token = user_crud.create_access_token(user_data)

return {"access_token": access_token, "refresh_token": refresh_token}


@router.get("/my-info/")
Expand Down
46 changes: 33 additions & 13 deletions src/backend/app/users/user_crud.py
Original file line number Diff line number Diff line change
@@ -1,38 +1,58 @@
import time
import jwt
from app.config import settings
from datetime import datetime, timedelta
from typing import Any
from passlib.context import CryptContext
from sqlalchemy.orm import Session
from app.db import db_models
from app.users.user_schemas import UserCreate
from sqlalchemy import text
from fastapi import HTTPException


pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")

ALGORITHM = "HS256"

def create_access_token(subject: str | Any):
expire = int(time.time()) + settings.ACCESS_TOKEN_EXPIRE_MINUTES
refresh_expire = int(time.time()) + settings.REFRESH_TOKEN_EXPIRE_MINUTES

def create_access_token(
subject: str | Any, expires_delta: timedelta, refresh_token_expiry: timedelta
):
expire = datetime.utcnow() + expires_delta
refresh_expire = datetime.utcnow() + refresh_token_expiry

to_encode_access_token = {"exp": expire, "sub": str(subject)}
to_encode_refresh_token = {"exp": refresh_expire, "sub": str(subject)}

# access token
subject["exp"] = expire
access_token = jwt.encode(
to_encode_access_token, settings.SECRET_KEY, algorithm=ALGORITHM
subject, settings.SECRET_KEY, algorithm=settings.ALGORITHM
)

# refresh token
subject["exp"] = refresh_expire
refresh_token = jwt.encode(
to_encode_refresh_token, settings.SECRET_KEY, algorithm=ALGORITHM
subject, settings.SECRET_KEY, algorithm=settings.ALGORITHM
)

return access_token, refresh_token


def verify_token(token: str):
"""Verifies the access token and returns the payload if valid.
Args:
token (str): The access token to be verified.
Returns:
dict: The payload of the access token if verification is successful.
Raises:
HTTPException: If the token has expired or credentials could not be validated.
"""
secret_key = settings.SECRET_KEY
try:
return jwt.decode(token, str(secret_key), algorithms=[settings.ALGORITHM])
except jwt.ExpiredSignatureError as e:
raise HTTPException(status_code=401, detail="Token has expired") from e
except Exception as e:
raise HTTPException(status_code=401, detail="Could not validate token") from e


def verify_password(plain_password: str, hashed_password: str) -> bool:
return pwd_context.verify(plain_password, hashed_password)

Expand Down
11 changes: 6 additions & 5 deletions src/backend/app/users/user_deps.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,16 +80,17 @@ async def login_required(
) -> AuthUser:
"""Dependency to inject into endpoints requiring login."""

google_auth = await init_google_auth()
if not access_token:
raise HTTPException(status_code=401, detail="No access token provided")

if not access_token:
raise HTTPException(status_code=401, detail="No access token provided")

try:
google_user = google_auth.deserialize_access_token(access_token)
except ValueError as e:
user = user_crud.verify_token(access_token)
except HTTPException as e:
log.error(e)
log.error("Failed to deserialise access token")
log.error("Failed to verify access token")
raise HTTPException(status_code=401, detail="Access token not valid") from e

return AuthUser(**google_user)
return AuthUser(**user)
10 changes: 2 additions & 8 deletions src/backend/app/users/user_routes.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
from app.config import settings
from app.users import user_crud
from app.db import database
from app.config import settings

router = APIRouter(
prefix=f"{settings.API_PREFIX}/users",
Expand All @@ -32,14 +31,9 @@ def login_access_token(
raise HTTPException(status_code=400, detail="Incorrect email or password")
elif not user.is_active:
raise HTTPException(status_code=400, detail="Inactive user")
access_token_expires = timedelta(minutes=settings.ACCESS_TOKEN_EXPIRE_MINUTES)
refresh_token_expires = timedelta(minutes=settings.REFRESH_TOKEN_EXPIRE_MINUTES)

access_token, refresh_token = user_crud.create_access_token(
user.id,
expires_delta=access_token_expires,
refresh_token_expiry=refresh_token_expires,
)
access_token, refresh_token = user_crud.create_access_token(user.id)

return Token(access_token=access_token, refresh_token=refresh_token)


Expand Down

0 comments on commit 7e0acd6

Please sign in to comment.