Skip to content

Commit

Permalink
Merge pull request #99 from stuartcampbell/main
Browse files Browse the repository at this point in the history
Fix API key generation
  • Loading branch information
stuartcampbell authored Jun 7, 2024
2 parents ccc9001 + d461398 commit 5d392be
Show file tree
Hide file tree
Showing 4 changed files with 72 additions and 16 deletions.
39 changes: 35 additions & 4 deletions src/nsls2api/api/v1/admin_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
validate_admin_role,
generate_api_key,
)
from nsls2api.models.apikeys import ApiUser
from nsls2api.models.apikeys import ApiUser, ApiUserRole, ApiUserResponseModel, ApiUserType
from nsls2api.models.slack_models import SlackChannelCreationResponseModel
from nsls2api.services import beamline_service, proposal_service, slack_service

Expand Down Expand Up @@ -42,15 +42,15 @@ async def check_admin_validation(
return admin_user.username


@router.post("/admin/generate_api_key/{username}")
async def generate_user_apikey(username: str):
@router.post("/admin/generate-api-key/{username}")
async def generate_user_apikey(username: str, usertype: ApiUserType = ApiUserType.user):
"""
Generate an API key for a given username.
:param username: The username for which to generate the API key.
:return: The generated API key.
"""
return await generate_api_key(username)
return await generate_api_key(username, usertype=usertype)


@router.post("/admin/proposal/generate-test")
Expand Down Expand Up @@ -148,3 +148,34 @@ async def create_slack_channel(proposal_id: str) -> SlackChannelCreationResponse
)

return response_model


@router.put("/admin/user/{username}/role/{role}")
async def update_user_role(username: str, role: ApiUserRole) -> ApiUserResponseModel:
"""
Update the role of a user.
:param username: The username of the user to update.
:param role: The new role for the user.
:return: The updated user object.
"""
user = await ApiUser.find_one(ApiUser.username==username)
if user is None:
raise HTTPException(
status_code=fastapi.status.HTTP_404_NOT_FOUND,
detail=f"User {username} not found",
)

user.role = role
await user.save()

response = ApiUserResponseModel(
id=user.id,
username=user.username,
type=user.type,
role=user.role,
created_on=user.created_on,
last_updated=user.last_updated,
)

return response
1 change: 1 addition & 0 deletions src/nsls2api/api/v1/beamline_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from nsls2api.models.beamlines import (
Beamline,
BeamlineService,
Detector,
DetectorList,
DirectoryList,
)
Expand Down
30 changes: 22 additions & 8 deletions src/nsls2api/infrastructure/security.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import calendar
import datetime
import enum
import logging
import secrets
from typing import Optional

Expand All @@ -12,6 +11,7 @@
from pydantic_settings import BaseSettings

from nsls2api.infrastructure.config import get_settings
from nsls2api.infrastructure.logging import logger
from nsls2api.models.apikeys import ApiKey, ApiUser, ApiUserType, ApiUserRole

TOKEN_BYTE_LENGTH = 32
Expand All @@ -20,8 +20,6 @@
api_key_header = APIKeyHeader(name="Authorization", auto_error=False)
api_key_query = APIKeyQuery(name="api_key", auto_error=False)

logger = logging.getLogger(__name__)


def hash_api_key(api_key):
return crypto.hash(api_key)
Expand All @@ -41,14 +39,14 @@ async def get_api_key(
return None


async def generate_api_key(username: str):
async def generate_api_key(username: str, usertype=ApiUserType.user):
try:
# Is there an API user for this key to be associated with
user = await ApiUser.find_one(ApiUser.username == username)
# If not create one
if not user:
print("No user found - creating user principal")
user = ApiUser(username=username, type=ApiUserType.user)
user = ApiUser(username=username, type=usertype)
await user.save(link_rule=WriteRules.WRITE)

# Actually generate the api key and add a readable prefix
Expand All @@ -69,19 +67,35 @@ async def generate_api_key(username: str):
expires_after=None,
)

# user.user_api_keys.append(new_key)
await user.update(link_rule=WriteRules.WRITE)
old_keys = await ApiKey.find(ApiKey.username == username).to_list()

await new_key.save(link_rule=WriteRules.WRITE)

# Now that we have saved a new key for this user, we should invalidate any other keys
for old_key in old_keys:
if old_key.valid:
logger.info(f"Invalidating old key: {old_key.secret_key}")
old_key.valid = False
await old_key.save(link_rule=WriteRules.WRITE)

return {"key:": secret_key}

except Exception as e:
print(e)
logger.exception(e)
raise e


async def set_user_role(username: str, role: ApiUserRole):
user = await ApiUser.find_one(ApiUser.username == username)
if user is None:
raise LookupError(f"Could not find a user with the username: {username}")

user.role = role
await user.save(link_rule=WriteRules.WRITE)

return user


async def lookup_api_key(token: str) -> ApiKey:
"""
:param token: The token used for API key lookup
Expand Down
18 changes: 14 additions & 4 deletions src/nsls2api/models/apikeys.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import datetime
import enum
from enum import StrEnum
from typing import Optional, List
from uuid import UUID, uuid4

Expand All @@ -10,16 +10,24 @@
from pydantic import Field


class ApiUserType(str, enum.Enum):
class ApiUserType(StrEnum):
user = "user"
service = "service"


class ApiUserRole(str, enum.Enum):
class ApiUserRole(StrEnum):
user = "user"
staff = "staff"
admin = "admin"

class ApiUserResponseModel(pydantic.BaseModel):
id: UUID
username: str
type: ApiUserType
role: ApiUserRole
created_on: datetime.datetime
last_updated: datetime.datetime


class ApiUser(beanie.Document):
id: UUID = Field(default_factory=uuid4)
Expand Down Expand Up @@ -60,7 +68,9 @@ class ApiKey(beanie.Document):
user: Link[ApiUser]
username: str
first_eight: pydantic.constr(min_length=8, max_length=8)
secret_key: str # TODO: After development - we will not be storing this one in the database
secret_key: (
str # TODO: After development - we will not be storing this one in the database
)
hashed_key: str
note: Optional[str] = ""
# scopes: Optional[list[str]] = pydantic.Field(..., example=["inherit"])
Expand Down

0 comments on commit 5d392be

Please sign in to comment.