Skip to content

Commit

Permalink
feat(app)!: Migrate configs to use settings service (#716)
Browse files Browse the repository at this point in the history
  • Loading branch information
daryllimyt authored Jan 9, 2025
1 parent 9481d6a commit 00df363
Show file tree
Hide file tree
Showing 6 changed files with 124 additions and 51 deletions.
29 changes: 22 additions & 7 deletions tests/unit/test_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -195,17 +195,32 @@ def test_parse_git_url(url: str, expected: GitUrl):
@pytest.mark.parametrize(
"url",
[
"git+ssh://[email protected]/tracecat-dev/[email protected]",
"git+ssh://[email protected]/tracecat-dev/[email protected]",
# Adding invalid cases from old test
"https://github.com/org/repo",
"git+ssh://[email protected]/org",
"git+ssh://[email protected]/org/repo@branch/extra",
pytest.param(
"git+ssh://[email protected]/tracecat-dev/[email protected]",
id="Invalid host domain tracecat.com",
),
pytest.param(
"git+ssh://[email protected]/tracecat-dev/[email protected]",
id="Invalid host domain git.com",
),
pytest.param(
"https://github.com/org/repo",
id="Invalid URL scheme - must be git+ssh",
),
pytest.param(
"git+ssh://[email protected]/org",
id="Missing repository name",
),
pytest.param(
"git+ssh://[email protected]/org/repo@branch/extra",
id="Invalid branch format with extra path component",
),
],
)
def test_parse_git_url_invalid(url: str):
allowed_domains = {"github.com", "gitlab.com"}
with pytest.raises(ValueError):
parse_git_url(url)
parse_git_url(url, allowed_domains=allowed_domains)


def test_construct_template_action():
Expand Down
54 changes: 27 additions & 27 deletions tracecat/api/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,11 @@
generic_exception_handler,
tracecat_exception_handler,
)
from tracecat.auth.dependencies import require_auth_type_enabled
from tracecat.auth.enums import AuthType
from tracecat.auth.models import UserCreate, UserRead, UserUpdate
from tracecat.auth.router import router as users_router
from tracecat.auth.saml import router as saml_router
from tracecat.auth.users import (
FastAPIUsersException,
InvalidDomainException,
Expand Down Expand Up @@ -186,33 +188,31 @@ def create_app(**kwargs) -> FastAPI:
tags=["auth"],
)

if AuthType.GOOGLE_OAUTH in config.TRACECAT__AUTH_TYPES:
oauth_client = GoogleOAuth2(
client_id=config.OAUTH_CLIENT_ID, client_secret=config.OAUTH_CLIENT_SECRET
)
# This is the frontend URL that the user will be redirected to after authenticating
redirect_url = f"{config.TRACECAT__PUBLIC_APP_URL}/auth/oauth/callback"
logger.info("OAuth redirect URL", url=redirect_url)
app.include_router(
fastapi_users.get_oauth_router(
oauth_client,
auth_backend,
config.USER_AUTH_SECRET,
# XXX(security): See https://fastapi-users.github.io/fastapi-users/13.0/configuration/oauth/#existing-account-association
associate_by_email=True,
is_verified_by_default=True,
# Points the user back to the login page
redirect_url=redirect_url,
),
prefix="/auth/oauth",
tags=["auth"],
)

if AuthType.SAML in config.TRACECAT__AUTH_TYPES:
from tracecat.auth.saml import router as saml_router

logger.info("SAML auth type enabled")
app.include_router(saml_router)
oauth_client = GoogleOAuth2(
client_id=config.OAUTH_CLIENT_ID, client_secret=config.OAUTH_CLIENT_SECRET
)
# This is the frontend URL that the user will be redirected to after authenticating
redirect_url = f"{config.TRACECAT__PUBLIC_APP_URL}/auth/oauth/callback"
logger.info("OAuth redirect URL", url=redirect_url)
app.include_router(
fastapi_users.get_oauth_router(
oauth_client,
auth_backend,
config.USER_AUTH_SECRET,
# XXX(security): See https://fastapi-users.github.io/fastapi-users/13.0/configuration/oauth/#existing-account-association
associate_by_email=True,
is_verified_by_default=True,
# Points the user back to the login page
redirect_url=redirect_url,
),
prefix="/auth/oauth",
tags=["auth"],
dependencies=[require_auth_type_enabled(AuthType.GOOGLE_OAUTH)],
)
app.include_router(
saml_router,
dependencies=[require_auth_type_enabled(AuthType.SAML)],
)

if AuthType.BASIC not in config.TRACECAT__AUTH_TYPES:
# Need basic auth router for `logout` endpoint
Expand Down
6 changes: 4 additions & 2 deletions tracecat/api/executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from tracecat.middleware import RequestLoggingMiddleware
from tracecat.registry.repositories.service import RegistryReposService
from tracecat.registry.repository import Repository
from tracecat.settings.service import get_setting
from tracecat.types.exceptions import TracecatException


Expand Down Expand Up @@ -45,11 +46,12 @@ async def setup_custom_remote_repository():
2. If it doesn't exist, create it
3. If it does exist, sync it
"""
url = config.TRACECAT__REMOTE_REPOSITORY_URL
role = bootstrap_role()
url = await get_setting("git_repo_url", role=role)
if not url:
logger.info("Remote repository URL not set, skipping")
return
role = bootstrap_role()
logger.info("Remote repository URL found", url=url)
async with RegistryReposService.with_session(role) as service:
db_repo = await service.get_repository(url)
# If it doesn't exist, do nothing
Expand Down
42 changes: 41 additions & 1 deletion tracecat/auth/dependencies.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,14 @@
from typing import Annotated
from typing import Annotated, Any

from fastapi import Depends, HTTPException, status

from tracecat import config
from tracecat.api.common import bootstrap_role
from tracecat.auth.credentials import RoleACL
from tracecat.auth.enums import AuthType
from tracecat.logger import logger
from tracecat.settings.constants import AUTH_TYPE_TO_SETTING_KEY
from tracecat.settings.service import get_setting
from tracecat.types.auth import Role

WorkspaceUserRole = Annotated[
Expand All @@ -11,3 +19,35 @@
Sets the `ctx_role` context variable.
"""


def require_auth_type_enabled(auth_type: AuthType) -> Any:
"""FastAPI dependency to check if an auth type is enabled."""

if auth_type not in AUTH_TYPE_TO_SETTING_KEY:
raise ValueError(f"Invalid auth type: {auth_type}")

async def _check_auth_type_enabled() -> None:
# 1. Check that this auth type is allowed
if auth_type not in config.TRACECAT__AUTH_TYPES:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="Auth type not allowed",
)
# 2. Check that the setting is enabled
key = AUTH_TYPE_TO_SETTING_KEY[auth_type]
logger.warning("Checking auth type enabled", key=key)
setting = await get_setting(key=key, role=bootstrap_role())
logger.warning("Setting", setting=setting)
if setting is None or not isinstance(setting, bool):
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="Invalid setting configuration",
)
if not setting:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail=f"Auth type {auth_type} is not enabled",
)

return Depends(_check_auth_type_enabled)
25 changes: 17 additions & 8 deletions tracecat/auth/users.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import uuid
from collections.abc import AsyncGenerator, Sequence
from datetime import UTC, datetime
from typing import Annotated
from typing import Annotated, cast

from fastapi import APIRouter, Depends, Request, Response, status
from fastapi_users import (
Expand Down Expand Up @@ -34,6 +34,7 @@
from sqlmodel.ext.asyncio.session import AsyncSession as SQLModelAsyncSession

from tracecat import config
from tracecat.api.common import bootstrap_role
from tracecat.auth.models import UserCreate, UserRole, UserUpdate
from tracecat.db.adapter import (
SQLModelAccessTokenDatabaseAsync,
Expand All @@ -42,6 +43,7 @@
from tracecat.db.engine import get_async_session, get_async_session_context_manager
from tracecat.db.schemas import AccessToken, OAuthAccount, User
from tracecat.logger import logger
from tracecat.settings.service import get_setting


class InvalidDomainException(FastAPIUsersException):
Expand All @@ -55,13 +57,21 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
def __init__(self, user_db: SQLAlchemyUserDatabase) -> None:
super().__init__(user_db)
self.logger = logger.bind(unit="UserManager")
self.role = bootstrap_role()

async def validate_password(self, password: str, user: User) -> None:
if len(password) < config.TRACECAT__AUTH_MIN_PASSWORD_LENGTH:
raise InvalidPasswordException(
f"Password must be at least {config.TRACECAT__AUTH_MIN_PASSWORD_LENGTH} characters long"
)

async def validate_email(self, email: str) -> None:
allowed_domains = cast(
list[str] | None,
await get_setting("auth_allowed_email_domains", role=self.role),
)
validate_email(email=email, allowed_domains=allowed_domains)

async def oauth_callback(
self,
oauth_name: str,
Expand All @@ -75,7 +85,7 @@ async def oauth_callback(
associate_by_email: bool = False,
is_verified_by_default: bool = False,
) -> User:
validate_email(account_email)
await self.validate_email(account_email)
return await super().oauth_callback( # type: ignore
oauth_name,
access_token,
Expand All @@ -94,7 +104,7 @@ async def create(
safe: bool = False,
request: Request | None = None,
) -> User:
validate_email(email=user_create.email)
await self.validate_email(user_create.email)
return await super().create(user_create, safe, request)

async def on_after_login(
Expand Down Expand Up @@ -311,13 +321,12 @@ async def list_users(*, session: SQLModelAsyncSession) -> Sequence[User]:
return result.all()


def validate_email(email: EmailStr) -> None:
def validate_email(
email: EmailStr, *, allowed_domains: list[str] | None = None
) -> None:
# Safety: This is already a validated email, so we can split on the first @
_, domain = email.split("@", 1)
logger.info(f"Domain: {domain}")

if (
config.TRACECAT__AUTH_ALLOWED_DOMAINS
and domain not in config.TRACECAT__AUTH_ALLOWED_DOMAINS
):
if allowed_domains and domain not in allowed_domains:
raise InvalidDomainException(f"You cannot register with the domain {domain!r}")
19 changes: 13 additions & 6 deletions tracecat/registry/repository.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,6 @@
from typing_extensions import Doc

from tracecat import config
from tracecat.config import TRACECAT__ALLOWED_GIT_DOMAINS
from tracecat.contexts import ctx_role
from tracecat.expressions.expectations import create_expectation_model
from tracecat.expressions.validation import TemplateValidator
Expand All @@ -48,6 +47,7 @@
from tracecat.registry.repositories.models import RegistryRepositoryCreate
from tracecat.registry.repositories.service import RegistryReposService
from tracecat.secrets.service import SecretsService
from tracecat.settings.service import get_setting
from tracecat.types.auth import Role
from tracecat.types.exceptions import RegistryError

Expand Down Expand Up @@ -246,8 +246,13 @@ async def load_from_origin(self, commit_sha: str | None = None) -> str | None:
# Load from remote
logger.info("Loading UDFs from origin", origin=self._origin)

allowed_domains = cast(
set[str],
await get_setting("git_allowed_domains", role=self.role) or {"github.com"},
)

try:
git_url = parse_git_url(self._origin)
git_url = parse_git_url(self._origin, allowed_domains=allowed_domains)
host = git_url.host
org = git_url.org
repo_name = git_url.repo
Expand All @@ -263,7 +268,9 @@ async def load_from_origin(self, commit_sha: str | None = None) -> str | None:
package_name=repo_name,
branch=branch,
)
package_name = config.TRACECAT__REMOTE_REPOSITORY_PACKAGE_NAME or repo_name
package_name = (
await get_setting("git_repo_package_name", role=self.role) or repo_name
)

cleaned_url = self.safe_remote_url(self._origin)
logger.debug("Cleaned URL", url=cleaned_url)
Expand Down Expand Up @@ -641,7 +648,7 @@ class GitUrl:
branch: str


def parse_git_url(url: str) -> GitUrl:
def parse_git_url(url: str, *, allowed_domains: set[str] | None = None) -> GitUrl:
"""
Parse a Git repository URL to extract organization, package name, and branch.
Handles Git SSH URLs with 'git+ssh' prefix and optional '@' for branch specification.
Expand All @@ -659,9 +666,9 @@ def parse_git_url(url: str) -> GitUrl:

if match := re.match(pattern, url):
host = match.group("host")
if host not in TRACECAT__ALLOWED_GIT_DOMAINS:
if allowed_domains and host not in allowed_domains:
raise ValueError(
f"Domain {host} not in allowed domains. Must be configured in TRACECAT__ALLOWED_GIT_DOMAINS."
f"Domain {host} not in allowed domains. Must be configured in `git_allowed_domains` organization setting."
)

return GitUrl(
Expand Down

0 comments on commit 00df363

Please sign in to comment.