Skip to content

Commit

Permalink
refactor+test: Improve testing for auth config dep (#723)
Browse files Browse the repository at this point in the history
  • Loading branch information
daryllimyt authored Jan 9, 2025
1 parent 628b900 commit d6f776f
Show file tree
Hide file tree
Showing 2 changed files with 120 additions and 20 deletions.
78 changes: 78 additions & 0 deletions tests/unit/test_dependencies.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
import pytest
from fastapi import HTTPException, status
from pytest_mock import MockerFixture

from tracecat.auth.dependencies import require_auth_type_enabled, verify_auth_type
from tracecat.auth.enums import AuthType


@pytest.mark.anyio
async def test_verify_auth_type_invalid_type():
"""Test that invalid auth types raise ValueError."""
with pytest.raises(ValueError, match="Invalid auth type"):
await require_auth_type_enabled("invalid_type") # type: ignore


@pytest.mark.parametrize(
"target_type,allowed_types",
[
pytest.param(
AuthType.BASIC,
[],
id="basic_auth",
),
pytest.param(
AuthType.SAML,
[AuthType.GOOGLE_OAUTH, AuthType.BASIC],
id="saml_auth",
),
],
)
@pytest.mark.anyio
async def test_verify_auth_type_not_allowed(
mocker: MockerFixture, target_type: AuthType, allowed_types: list[AuthType]
):
"""Test that unauthorized auth types raise HTTPException."""
mocker.patch("tracecat.config.TRACECAT__AUTH_TYPES", allowed_types)

with pytest.raises(HTTPException) as exc:
await verify_auth_type(target_type)

assert exc.value.status_code == status.HTTP_403_FORBIDDEN
assert exc.value.detail == "Auth type not allowed"


@pytest.mark.anyio
async def test_verify_auth_type_setting_disabled(mocker: MockerFixture):
"""Test that disabled auth types raise HTTPException."""
mocker.patch("tracecat.config.TRACECAT__AUTH_TYPES", [AuthType.BASIC])
mocker.patch("tracecat.auth.dependencies.get_setting", return_value=False)

with pytest.raises(HTTPException) as exc:
await verify_auth_type(AuthType.BASIC)

assert exc.value.status_code == status.HTTP_403_FORBIDDEN
assert exc.value.detail == f"Auth type {AuthType.BASIC.value} is not enabled"


@pytest.mark.anyio
async def test_verify_auth_type_invalid_setting(mocker: MockerFixture):
"""Test that invalid settings raise HTTPException."""
mocker.patch("tracecat.config.TRACECAT__AUTH_TYPES", [AuthType.BASIC])
mocker.patch("tracecat.auth.dependencies.get_setting", return_value=None)

with pytest.raises(HTTPException) as exc:
await verify_auth_type(AuthType.BASIC)

assert exc.value.status_code == status.HTTP_500_INTERNAL_SERVER_ERROR
assert exc.value.detail == "Invalid setting configuration"


@pytest.mark.anyio
async def test_verify_auth_type_success(mocker: MockerFixture):
"""Test successful auth type verification."""
mocker.patch("tracecat.config.TRACECAT__AUTH_TYPES", [AuthType.BASIC])
mocker.patch("tracecat.auth.dependencies.get_setting", return_value=True)

# Should not raise any exceptions
await verify_auth_type(AuthType.BASIC)
62 changes: 42 additions & 20 deletions tracecat/auth/dependencies.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,31 +20,53 @@
"""


async def verify_auth_type(auth_type: AuthType) -> None:
"""Verify if an auth type is enabled and properly configured.
Args:
auth_type: The authentication type to verify
Raises:
HTTPException: If the auth type is not allowed or not enabled
ValueError: If the auth type is invalid
"""

# 1. Check that this auth type is allowed
if auth_type not in config.TRACECAT__AUTH_TYPES:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="Auth type not allowed",
)

# 2. Check that the setting is enabled
key = AUTH_TYPE_TO_SETTING_KEY[auth_type]
setting = await get_setting(key=key, role=bootstrap_role())
if setting is None or not isinstance(setting, bool):
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="Invalid setting configuration",
)
if not setting:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail=f"Auth type {auth_type} is not enabled",
)


def require_auth_type_enabled(auth_type: AuthType) -> Any:
"""FastAPI dependency to check if an auth type is enabled."""
"""FastAPI dependency to check if an auth type is enabled.
Args:
auth_type: The authentication type to check
Returns:
FastAPI dependency that verifies the auth type
"""

if auth_type not in AUTH_TYPE_TO_SETTING_KEY:
raise ValueError(f"Invalid auth type: {auth_type}")

async def _check_auth_type_enabled() -> None:
# 1. Check that this auth type is allowed
if auth_type not in config.TRACECAT__AUTH_TYPES:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="Auth type not allowed",
)
# 2. Check that the setting is enabled
key = AUTH_TYPE_TO_SETTING_KEY[auth_type]
setting = await get_setting(key=key, role=bootstrap_role())
if setting is None or not isinstance(setting, bool):
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="Invalid setting configuration",
)
if not setting:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail=f"Auth type {auth_type} is not enabled",
)
await verify_auth_type(auth_type)

return Depends(_check_auth_type_enabled)

0 comments on commit d6f776f

Please sign in to comment.