From 6feb5232d732bd3ed3b2beea07b413e65fbe1a98 Mon Sep 17 00:00:00 2001 From: Daryl Lim <5508348+daryllimyt@users.noreply.github.com> Date: Thu, 5 Dec 2024 18:28:50 +0000 Subject: [PATCH 01/13] feat(engine): Separate registry service (#556) --- .github/workflows/test-python.yml | 2 +- Caddyfile | 3 + docker-compose.dev.yml | 41 +++- docker-compose.yml | 29 +++ tests/conftest.py | 5 + tests/unit/test_workflows.py | 23 +- tracecat/api/app.py | 133 +----------- tracecat/api/common.py | 124 +++++++++++ tracecat/api/registry.py | 92 ++++++++ tracecat/config.py | 8 +- tracecat/db/engine.py | 2 +- tracecat/dsl/models.py | 5 + tracecat/expressions/eval.py | 7 +- tracecat/identifiers/__init__.py | 1 + tracecat/registry/actions/router.py | 9 +- tracecat/registry/client.py | 23 +- tracecat/registry/constants.py | 6 + tracecat/registry/executor.py | 262 +++++++++++++---------- tracecat/registry/repositories/router.py | 3 +- 19 files changed, 494 insertions(+), 284 deletions(-) create mode 100644 tracecat/api/common.py create mode 100644 tracecat/api/registry.py diff --git a/.github/workflows/test-python.yml b/.github/workflows/test-python.yml index 534283dd5..3b29154f5 100644 --- a/.github/workflows/test-python.yml +++ b/.github/workflows/test-python.yml @@ -126,7 +126,7 @@ jobs: - name: Start Docker services env: TRACECAT__UNSAFE_DISABLE_SM_MASKING: "true" - run: docker compose -f docker-compose.dev.yml up --build --no-deps -d api worker postgres_db caddy + run: docker compose -f docker-compose.dev.yml up --build --no-deps -d api worker registry postgres_db caddy - name: Install dependencies run: | diff --git a/Caddyfile b/Caddyfile index 159a63d84..16fe5f3bb 100644 --- a/Caddyfile +++ b/Caddyfile @@ -1,5 +1,8 @@ {$BASE_DOMAIN} { bind {$ADDRESS} # Binds to all available network interfaces if not specified + handle_path /api/registry* { + reverse_proxy http://registry:8000 + } handle_path /api* { reverse_proxy http://api:8000 } diff --git a/docker-compose.dev.yml b/docker-compose.dev.yml index 54ef40a5a..d374102a6 100644 --- a/docker-compose.dev.yml +++ b/docker-compose.dev.yml @@ -36,6 +36,7 @@ services: TRACECAT__AUTH_TYPES: ${TRACECAT__AUTH_TYPES} TRACECAT__AUTH_ALLOWED_DOMAINS: ${TRACECAT__AUTH_ALLOWED_DOMAINS} TRACECAT__AUTH_MIN_PASSWORD_LENGTH: ${TRACECAT__AUTH_MIN_PASSWORD_LENGTH} + TRACECAT__REGISTRY_URL: ${INTERNAL_REGISTRY_URL} OAUTH_CLIENT_ID: ${OAUTH_CLIENT_ID} OAUTH_CLIENT_SECRET: ${OAUTH_CLIENT_SECRET} USER_AUTH_SECRET: ${USER_AUTH_SECRET} @@ -56,8 +57,6 @@ services: # AI TRACECAT__PRELOAD_OSS_MODELS: ${TRACECAT__PRELOAD_OSS_MODELS} OLLAMA__API_URL: ${OLLAMA__API_URL} - # This is only used for testing - TRACECAT__UNSAFE_DISABLE_SM_MASKING: ${TRACECAT__UNSAFE_DISABLE_SM_MASKING:-true} volumes: - ./tracecat:/app/tracecat - ./registry:/app/registry @@ -78,9 +77,10 @@ services: TRACECAT__DB_ENCRYPTION_KEY: ${TRACECAT__DB_ENCRYPTION_KEY} # Sensitive TRACECAT__DB_SSLMODE: ${TRACECAT__DB_SSLMODE} TRACECAT__DB_URI: ${TRACECAT__DB_URI} # Sensitive - TRACECAT__PUBLIC_RUNNER_URL: ${TRACECAT__PUBLIC_RUNNER_URL} TRACECAT__SERVICE_KEY: ${TRACECAT__SERVICE_KEY} # Sensitive TRACECAT__SIGNING_SECRET: ${TRACECAT__SIGNING_SECRET} # Sensitive + TRACECAT__REGISTRY_URL: ${INTERNAL_REGISTRY_URL} + TRACECAT__PUBLIC_RUNNER_URL: ${TRACECAT__PUBLIC_RUNNER_URL} # Temporal TEMPORAL__CLUSTER_URL: ${TEMPORAL__CLUSTER_URL} TEMPORAL__CLUSTER_QUEUE: ${TEMPORAL__CLUSTER_QUEUE} @@ -89,6 +89,40 @@ services: - ./registry:/app/registry entrypoint: ["python", "tracecat/dsl/worker.py"] + registry: + build: + context: . + dockerfile: Dockerfile.dev + restart: unless-stopped + environment: + # Common + LOG_LEVEL: ${LOG_LEVEL} + TRACECAT__APP_ENV: ${TRACECAT__APP_ENV} + TRACECAT__DB_ENCRYPTION_KEY: ${TRACECAT__DB_ENCRYPTION_KEY} # Sensitive + TRACECAT__DB_SSLMODE: ${TRACECAT__DB_SSLMODE} + TRACECAT__DB_URI: ${TRACECAT__DB_URI} # Sensitive + TRACECAT__SERVICE_KEY: ${TRACECAT__SERVICE_KEY} # Sensitive + TRACECAT__SIGNING_SECRET: ${TRACECAT__SIGNING_SECRET} # Sensitive + # Registry + TRACECAT__REMOTE_REPOSITORY_URL: ${TRACECAT__REMOTE_REPOSITORY_URL} + TRACECAT__REMOTE_REPOSITORY_PACKAGE_NAME: ${TRACECAT__REMOTE_REPOSITORY_PACKAGE_NAME} + TRACECAT__UNSAFE_DISABLE_SM_MASKING: ${TRACECAT__UNSAFE_DISABLE_SM_MASKING:-false} + volumes: + - ./tracecat:/app/tracecat + - ./registry:/app/registry + entrypoint: + [ + "python", + "-m", + "uvicorn", + "tracecat.api.registry:app", + "--host", + "0.0.0.0", + "--port", + "8000", + "--reload", + ] + ui: build: context: ./frontend @@ -115,6 +149,7 @@ services: - ./frontend/src:/app/src - ./frontend/.next:/app/.next - ./frontend/node_modules:/app/node_modules + attach: false postgres_db: image: postgres:16 diff --git a/docker-compose.yml b/docker-compose.yml index e9f728dce..5e5a2c065 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -82,6 +82,35 @@ services: command: ["python", "tracecat/dsl/worker.py"] + registry: + image: ghcr.io/tracecathq/tracecat:${TRACECAT__IMAGE_TAG:-0.16.0} + restart: unless-stopped + networks: + - core + - core-db + environment: + # Common + LOG_LEVEL: ${LOG_LEVEL} + TRACECAT__APP_ENV: ${TRACECAT__APP_ENV} + TRACECAT__DB_ENCRYPTION_KEY: ${TRACECAT__DB_ENCRYPTION_KEY} # Sensitive + TRACECAT__DB_SSLMODE: ${TRACECAT__DB_SSLMODE} + TRACECAT__DB_URI: ${TRACECAT__DB_URI} # Sensitive + TRACECAT__SERVICE_KEY: ${TRACECAT__SERVICE_KEY} # Sensitive + # Registry + TRACECAT__REMOTE_REPOSITORY_URL: ${TRACECAT__REMOTE_REPOSITORY_URL} + TRACECAT__REMOTE_REPOSITORY_PACKAGE_NAME: ${TRACECAT__REMOTE_REPOSITORY_PACKAGE_NAME} + entrypoint: + [ + "python", + "-m", + "uvicorn", + "tracecat.api.registry:app", + "--host", + "0.0.0.0", + "--port", + "8000", + ] + ui: image: ghcr.io/tracecathq/tracecat-ui:${TRACECAT__IMAGE_TAG:-0.16.0} container_name: ui diff --git a/tests/conftest.py b/tests/conftest.py index 3cf5d8334..ccb4072c3 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -46,6 +46,7 @@ def env_sandbox(monkeysession: pytest.MonkeyPatch): load_dotenv() logger.info("Setting up environment variables") + monkeysession.setattr(config, "TRACECAT__APP_ENV", "development") monkeysession.setattr( config, "TRACECAT__DB_URI", @@ -58,6 +59,9 @@ def env_sandbox(monkeysession: pytest.MonkeyPatch): "TRACECAT__REMOTE_REPOSITORY_URL", "git+ssh://git@github.com/TracecatHQ/udfs.git", ) + monkeysession.setattr( + config, "TRACECAT__REGISTRY_URL", "http://localhost/api/registry" + ) monkeysession.setenv( "TRACECAT__DB_URI", @@ -65,6 +69,7 @@ def env_sandbox(monkeysession: pytest.MonkeyPatch): ) # monkeysession.setenv("TRACECAT__DB_ENCRYPTION_KEY", Fernet.generate_key().decode()) monkeysession.setenv("TRACECAT__API_URL", "http://api:8000") + monkeysession.setenv("TRACECAT__REGISTRY_URL", "http://registry:8000") monkeysession.setenv("TRACECAT__PUBLIC_API_URL", "http://localhost/api") monkeysession.setenv("TRACECAT__PUBLIC_RUNNER_URL", "http://localhost:8001") monkeysession.setenv("TRACECAT__SERVICE_KEY", os.environ["TRACECAT__SERVICE_KEY"]) diff --git a/tests/unit/test_workflows.py b/tests/unit/test_workflows.py index 95370aa10..9f7a9eb46 100644 --- a/tests/unit/test_workflows.py +++ b/tests/unit/test_workflows.py @@ -43,15 +43,15 @@ from tracecat.workflow.management.management import WorkflowsManagementService -@pytest.mark.skipif( - os.environ.get("GITHUB_ACTIONS") is not None, - reason="Skip if running in GitHub Actions", -) -@pytest.fixture(autouse=True, scope="module") +@pytest.fixture(scope="module") def hotfix_local_api_url(monkeysession: pytest.MonkeyPatch): - # NOTE: This is a hotfix to allow the workflow tests to run locally. - # We need to set the internal API url to the public API url - # otherwise the tests will fail because it cannot reach the internal API + """Hotfix to allow workflow tests to run locally. + + We need to set the internal API url to the public API url + otherwise the tests will fail because it cannot reach the internal API. + """ + if os.environ.get("GITHUB_ACTIONS") is not None: + pytest.skip("Skip if running in GitHub Actions") monkeysession.setattr(config, "TRACECAT__API_URL", "http://localhost/api") @@ -1578,6 +1578,7 @@ async def test_pull_based_workflow_fetches_latest_version(temporal_client, test_ assert result == "__EXPECTED_SECOND_RESULT__" +# Get the line number dynamically DIVISION_BY_ZERO_ERROR = { "ref": "start", "message": ( @@ -1598,9 +1599,9 @@ async def test_pull_based_workflow_fetches_latest_version(temporal_client, test_ "Cannot divide by zero\n" "\n" "------------------------------\n" - "File: /app/tracecat/expressions/core.py\n" - "Function: result\n" - "Line: 51" + "File: /app/tracecat/registry/executor.py\n" + "Function: run_action_in_pool\n" + "Line: 83" ), "type": "RegistryActionError", "expr_context": "ACTIONS", diff --git a/tracecat/api/app.py b/tracecat/api/app.py index 533e9e862..2832b06b7 100644 --- a/tracecat/api/app.py +++ b/tracecat/api/app.py @@ -1,16 +1,19 @@ from contextlib import asynccontextmanager -from urllib.parse import urlparse from fastapi import FastAPI, Request, status from fastapi.exceptions import RequestValidationError from fastapi.middleware.cors import CORSMiddleware from fastapi.responses import ORJSONResponse -from fastapi.routing import APIRoute from httpx_oauth.clients.google import GoogleOAuth2 from sqlalchemy.exc import IntegrityError from sqlmodel.ext.asyncio.session import AsyncSession from tracecat import config +from tracecat.api.common import ( + custom_generate_unique_id, + generic_exception_handler, + tracecat_exception_handler, +) from tracecat.auth.constants import AuthType from tracecat.auth.models import UserCreate, UserRead, UserUpdate from tracecat.auth.router import router as users_router @@ -26,16 +29,6 @@ from tracecat.logger import logger from tracecat.middleware import RequestLoggingMiddleware from tracecat.organization.router import router as org_router -from tracecat.registry.actions.router import router as registry_actions_router -from tracecat.registry.actions.service import RegistryActionsService -from tracecat.registry.constants import ( - CUSTOM_REPOSITORY_ORIGIN, - DEFAULT_REGISTRY_ORIGIN, -) -from tracecat.registry.repositories.models import RegistryRepositoryCreate -from tracecat.registry.repositories.router import router as registry_repos_router -from tracecat.registry.repositories.service import RegistryReposService -from tracecat.registry.repository import safe_url from tracecat.secrets.router import router as secrets_router from tracecat.types.auth import AccessLevel, Role from tracecat.types.exceptions import TracecatException @@ -57,64 +50,9 @@ async def lifespan(app: FastAPI): ) async with get_async_session_context_manager() as session: await setup_defaults(session, admin_role) - await setup_registry(session, admin_role) - await setup_oss_models() yield -async def setup_registry(session: AsyncSession, admin_role: Role): - logger.info("Setting up base registry repository") - repos_service = RegistryReposService(session, role=admin_role) - # Setup Tracecat base repository - base_origin = DEFAULT_REGISTRY_ORIGIN - # Check if the base registry repository already exists - # NOTE: Should we sync the base repo every time? - if await repos_service.get_repository(base_origin) is None: - base_repo = await repos_service.create_repository( - RegistryRepositoryCreate(origin=base_origin) - ) - logger.info("Created base registry repository", origin=base_origin) - actions_service = RegistryActionsService(session, role=admin_role) - await actions_service.sync_actions_from_repository(base_repo) - else: - logger.info("Base registry repository already exists", origin=base_origin) - - # Setup custom repository - custom_origin = CUSTOM_REPOSITORY_ORIGIN - if await repos_service.get_repository(custom_origin) is None: - await repos_service.create_repository( - RegistryRepositoryCreate(origin=custom_origin) - ) - logger.info("Created custom repository", origin=custom_origin) - else: - logger.info("Custom repository already exists", origin=custom_origin) - - # Setup custom remote repository - if (remote_url := config.TRACECAT__REMOTE_REPOSITORY_URL) is not None: - parsed_url = urlparse(remote_url) - logger.info("Setting up remote registry repository", url=parsed_url) - # Create it if it doesn't exist - - cleaned_url = safe_url(remote_url) - if await repos_service.get_repository(cleaned_url) is None: - await repos_service.create_repository( - RegistryRepositoryCreate(origin=cleaned_url) - ) - logger.info("Created remote registry repository", url=cleaned_url) - else: - logger.info("Remote registry repository already exists", url=cleaned_url) - # Load remote repository - else: - logger.info("Remote registry repository not set, skipping") - - repos = await repos_service.list_repositories() - logger.info( - "Found registry repositories", - n=len(repos), - repos=[repo.origin for repo in repos], - ) - - async def setup_defaults(session: AsyncSession, admin_role: Role): ws_service = WorkspaceService(session, role=admin_role) workspaces = await ws_service.admin_list_workspaces() @@ -129,58 +67,7 @@ async def setup_defaults(session: AsyncSession, admin_role: Role): logger.info("Default workspace already exists, skipping") -async def setup_oss_models(): - if not (preload_models := config.TRACECAT__PRELOAD_OSS_MODELS): - return - from tracecat.llm import preload_ollama_models - - logger.info( - f"Preloading {len(preload_models)} models", - models=preload_models, - ) - await preload_ollama_models(preload_models) - logger.info("Preloaded models", models=preload_models) - - -def custom_generate_unique_id(route: APIRoute): - if route.tags: - return f"{route.tags[0]}-{route.name}" - return route.name - - # Catch-all exception handler to prevent stack traces from leaking -def generic_exception_handler(request: Request, exc: Exception): - logger.error( - "Unexpected error", - exc=exc, - role=ctx_role.get(), - params=request.query_params, - path=request.url.path, - ) - return ORJSONResponse( - status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, - content={"message": "An unexpected error occurred. Please try again later."}, - ) - - -def tracecat_exception_handler(request: Request, exc: TracecatException): - """Generic exception handler for Tracecat exceptions. - - We can customize exceptions to expose only what should be user facing. - """ - msg = str(exc) - logger.error( - msg, - role=ctx_role.get(), - params=request.query_params, - path=request.url.path, - ) - return ORJSONResponse( - status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, - content={"type": type(exc).__name__, "message": msg, "detail": exc.detail}, - ) - - def validation_exception_handler(request: Request, exc: RequestValidationError): """Improves visiblity of 422 errors.""" exc_str = f"{exc}".replace("\n", " ").replace(" ", " ") @@ -208,7 +95,6 @@ def fastapi_users_auth_exception_handler(request: Request, exc: FastAPIUsersExce def create_app(**kwargs) -> FastAPI: - global logger if config.TRACECAT__ALLOW_ORIGINS is not None: allow_origins = config.TRACECAT__ALLOW_ORIGINS.split(",") else: @@ -251,8 +137,6 @@ def create_app(**kwargs) -> FastAPI: app.include_router(secrets_router) app.include_router(schedules_router) app.include_router(users_router) - app.include_router(registry_repos_router) - app.include_router(registry_actions_router) app.include_router(org_router) app.include_router(editor_router) app.include_router( @@ -326,11 +210,8 @@ def create_app(**kwargs) -> FastAPI: # Exception handlers app.add_exception_handler(Exception, generic_exception_handler) - app.add_exception_handler(TracecatException, tracecat_exception_handler) # type: ignore - app.add_exception_handler( - RequestValidationError, - validation_exception_handler, # type: ignore - ) + app.add_exception_handler(TracecatException, tracecat_exception_handler) # type: ignore # type: ignore + app.add_exception_handler(RequestValidationError, validation_exception_handler) # type: ignore app.add_exception_handler( FastAPIUsersException, fastapi_users_auth_exception_handler, # type: ignore diff --git a/tracecat/api/common.py b/tracecat/api/common.py new file mode 100644 index 000000000..283c87544 --- /dev/null +++ b/tracecat/api/common.py @@ -0,0 +1,124 @@ +from urllib.parse import urlparse + +from fastapi import Request, status +from fastapi.responses import ORJSONResponse +from fastapi.routing import APIRoute +from sqlmodel.ext.asyncio.session import AsyncSession + +from tracecat import config +from tracecat.contexts import ctx_role +from tracecat.logger import logger +from tracecat.registry.actions.service import RegistryActionsService +from tracecat.registry.constants import ( + CUSTOM_REPOSITORY_ORIGIN, + DEFAULT_REGISTRY_ORIGIN, +) +from tracecat.registry.repositories.models import RegistryRepositoryCreate +from tracecat.registry.repositories.service import RegistryReposService +from tracecat.registry.repository import safe_url +from tracecat.types.auth import Role +from tracecat.types.exceptions import TracecatException + + +def generic_exception_handler(request: Request, exc: Exception): + logger.error( + "Unexpected error", + exc=exc, + role=ctx_role.get(), + params=request.query_params, + path=request.url.path, + ) + return ORJSONResponse( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + content={"message": "An unexpected error occurred. Please try again later."}, + ) + + +async def setup_registry(session: AsyncSession, admin_role: Role): + logger.info("Setting up base registry repository") + repos_service = RegistryReposService(session, role=admin_role) + # Setup Tracecat base repository + base_origin = DEFAULT_REGISTRY_ORIGIN + # Check if the base registry repository already exists + # NOTE: Should we sync the base repo every time? + if await repos_service.get_repository(base_origin) is None: + base_repo = await repos_service.create_repository( + RegistryRepositoryCreate(origin=base_origin) + ) + logger.info("Created base registry repository", origin=base_origin) + actions_service = RegistryActionsService(session, role=admin_role) + await actions_service.sync_actions_from_repository(base_repo) + else: + logger.info("Base registry repository already exists", origin=base_origin) + + # Setup custom repository + custom_origin = CUSTOM_REPOSITORY_ORIGIN + if await repos_service.get_repository(custom_origin) is None: + await repos_service.create_repository( + RegistryRepositoryCreate(origin=custom_origin) + ) + logger.info("Created custom repository", origin=custom_origin) + else: + logger.info("Custom repository already exists", origin=custom_origin) + + # Setup custom remote repository + if (remote_url := config.TRACECAT__REMOTE_REPOSITORY_URL) is not None: + parsed_url = urlparse(remote_url) + logger.info("Setting up remote registry repository", url=parsed_url) + # Create it if it doesn't exist + + cleaned_url = safe_url(remote_url) + if await repos_service.get_repository(cleaned_url) is None: + await repos_service.create_repository( + RegistryRepositoryCreate(origin=cleaned_url) + ) + logger.info("Created remote registry repository", url=cleaned_url) + else: + logger.info("Remote registry repository already exists", url=cleaned_url) + # Load remote repository + else: + logger.info("Remote registry repository not set, skipping") + + repos = await repos_service.list_repositories() + logger.info( + "Found registry repositories", + n=len(repos), + repos=[repo.origin for repo in repos], + ) + + +def tracecat_exception_handler(request: Request, exc: TracecatException): + """Generic exception handler for Tracecat exceptions. + + We can customize exceptions to expose only what should be user facing. + """ + msg = str(exc) + logger.error( + msg, + role=ctx_role.get(), + params=request.query_params, + path=request.url.path, + ) + return ORJSONResponse( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + content={"type": type(exc).__name__, "message": msg, "detail": exc.detail}, + ) + + +def custom_generate_unique_id(route: APIRoute): + if route.tags: + return f"{route.tags[0]}-{route.name}" + return route.name + + +async def setup_oss_models(): + if not (preload_models := config.TRACECAT__PRELOAD_OSS_MODELS): + return + from tracecat.llm import preload_ollama_models + + logger.info( + f"Preloading {len(preload_models)} models", + models=preload_models, + ) + await preload_ollama_models(preload_models) + logger.info("Preloaded models", models=preload_models) diff --git a/tracecat/api/registry.py b/tracecat/api/registry.py new file mode 100644 index 000000000..796955803 --- /dev/null +++ b/tracecat/api/registry.py @@ -0,0 +1,92 @@ +from contextlib import asynccontextmanager + +from fastapi import FastAPI +from fastapi.middleware.cors import CORSMiddleware +from fastapi.responses import ORJSONResponse + +from tracecat import config +from tracecat.api.common import ( + custom_generate_unique_id, + generic_exception_handler, + setup_oss_models, + setup_registry, + tracecat_exception_handler, +) +from tracecat.db.engine import get_async_session_context_manager +from tracecat.logger import logger +from tracecat.middleware import RequestLoggingMiddleware +from tracecat.registry.actions.router import router as registry_actions_router +from tracecat.registry.executor import get_executor +from tracecat.registry.repositories.router import router as registry_repos_router +from tracecat.types.auth import AccessLevel, Role +from tracecat.types.exceptions import TracecatException + + +@asynccontextmanager +async def lifespan(app: FastAPI): + admin_role = Role( + type="service", + access_level=AccessLevel.ADMIN, + service_id="tracecat-registry", + ) + async with get_async_session_context_manager() as session: + await setup_registry(session, admin_role) + await setup_oss_models() + try: + executor = get_executor() + yield + finally: + executor.shutdown() + + +def create_app(**kwargs) -> FastAPI: + if config.TRACECAT__ALLOW_ORIGINS is not None: + allow_origins = config.TRACECAT__ALLOW_ORIGINS.split(",") + else: + allow_origins = ["*"] + app = FastAPI( + title="Tracecat Registry", + description="Registry action executor.", + summary="Tracecat Registry", + lifespan=lifespan, + default_response_class=ORJSONResponse, + generate_unique_id_function=custom_generate_unique_id, + root_path="/api/registry", + **kwargs, + ) + app.logger = logger # type: ignore + + # Routers + app.include_router(registry_repos_router) + app.include_router(registry_actions_router) + + # Exception handlers + app.add_exception_handler(Exception, generic_exception_handler) + app.add_exception_handler(TracecatException, tracecat_exception_handler) # type: ignore + + # Middleware + app.add_middleware(RequestLoggingMiddleware) + app.add_middleware( + CORSMiddleware, + allow_origins=allow_origins, + allow_credentials=True, + allow_methods=["*"], + allow_headers=["*"], + ) + + logger.info( + "Registry service started", + env=config.TRACECAT__APP_ENV, + origins=allow_origins, + auth_types=config.TRACECAT__AUTH_TYPES, + ) + + return app + + +app = create_app() + + +@app.get("/", include_in_schema=False) +def root() -> dict[str, str]: + return {"message": "Hello world. I am the registry."} diff --git a/tracecat/config.py b/tracecat/config.py index 5d84f6315..07507ef63 100644 --- a/tracecat/config.py +++ b/tracecat/config.py @@ -15,8 +15,9 @@ TRACECAT__SCHEDULE_MAX_CONNECTIONS = 6 TRACECAT__APP_ENV: Literal["development", "staging", "production"] = os.environ.get( "TRACECAT__APP_ENV", "development" -) +) # type: ignore TRACECAT__API_URL = os.environ.get("TRACECAT__API_URL", "http://localhost:8000") +TRACECAT__API_ROOT_PATH = os.environ.get("TRACECAT__API_ROOT_PATH", "/api") TRACECAT__PUBLIC_RUNNER_URL = os.environ.get( "TRACECAT__PUBLIC_RUNNER_URL", "http://localhost/api" ) @@ -31,6 +32,9 @@ "TRACECAT__DB_URI", "postgresql+psycopg://postgres:postgres@postgres_db:5432/postgres", ) +TRACECAT__REGISTRY_URL = os.environ.get( + "TRACECAT__REGISTRY_URL", "http://registry:8000" +) TRACECAT__DB_NAME = os.environ.get("TRACECAT__DB_NAME") TRACECAT__DB_USER = os.environ.get("TRACECAT__DB_USER") @@ -39,8 +43,6 @@ TRACECAT__DB_ENDPOINT = os.environ.get("TRACECAT__DB_ENDPOINT") TRACECAT__DB_PORT = os.environ.get("TRACECAT__DB_PORT") -TRACECAT__API_ROOT_PATH = os.environ.get("TRACECAT__API_ROOT_PATH", "/api") - # TODO: Set this as an environment variable TRACECAT__SERVICE_ROLES_WHITELIST = [ "tracecat-runner", diff --git a/tracecat/db/engine.py b/tracecat/db/engine.py index cf89ff346..0315ec03c 100644 --- a/tracecat/db/engine.py +++ b/tracecat/db/engine.py @@ -143,7 +143,7 @@ def get_session() -> Generator[Session, None, None]: yield session -async def get_async_session() -> AsyncGenerator[AsyncSession, None, None]: +async def get_async_session() -> AsyncGenerator[AsyncSession, None]: async_engine = get_async_engine() async with AsyncSession(async_engine, expire_on_commit=False) as async_session: yield async_session diff --git a/tracecat/dsl/models.py b/tracecat/dsl/models.py index 48df77522..50359636f 100644 --- a/tracecat/dsl/models.py +++ b/tracecat/dsl/models.py @@ -166,6 +166,8 @@ class DSLEnvironment(TypedDict, total=False): class DSLContext(TypedDict, total=False): + """DSL Context. Contains all the context needed to execute a DSL workflow.""" + INPUTS: dict[str, Any] """DSL Static Inputs context""" @@ -178,6 +180,9 @@ class DSLContext(TypedDict, total=False): ENV: DSLEnvironment """DSL Environment context. Has metadata about the workflow.""" + SECRETS: dict[str, Any] + """DSL Secrets context""" + class RunActionInput(BaseModel): """This object contains all the information needed to execute an action.""" diff --git a/tracecat/expressions/eval.py b/tracecat/expressions/eval.py index 551211edf..464f26a6c 100644 --- a/tracecat/expressions/eval.py +++ b/tracecat/expressions/eval.py @@ -9,8 +9,7 @@ T = TypeVar("T", str, list[Any], dict[str, Any]) - -OperatorType = Callable[[re.Match[str]], str] +OperatorType = Callable[[str], Any] OperandType = Mapping[str, Any] @@ -113,8 +112,8 @@ def operator(line: str) -> None: def get_iterables_from_expression( expr: str | list[str], operand: OperandType -) -> list[IterableExpr]: - iterable_exprs: IterableExpr | list[IterableExpr] = eval_templated_object( +) -> list[IterableExpr[Any]]: + iterable_exprs: IterableExpr[Any] | list[IterableExpr[Any]] = eval_templated_object( expr, operand=operand ) if isinstance(iterable_exprs, IterableExpr): diff --git a/tracecat/identifiers/__init__.py b/tracecat/identifiers/__init__.py index 9a535ac52..d883a4778 100644 --- a/tracecat/identifiers/__init__.py +++ b/tracecat/identifiers/__init__.py @@ -71,6 +71,7 @@ "tracecat-cli", "tracecat-schedule-runner", "tracecat-service", + "tracecat-registry", ] __all__ = [ diff --git a/tracecat/registry/actions/router.py b/tracecat/registry/actions/router.py index dc407d768..8af7da1be 100644 --- a/tracecat/registry/actions/router.py +++ b/tracecat/registry/actions/router.py @@ -6,7 +6,7 @@ from tracecat.auth.credentials import RoleACL from tracecat.concurrency import GatheringTaskGroup -from tracecat.contexts import ctx_logger +from tracecat.contexts import ctx_logger, ctx_role from tracecat.db.dependencies import AsyncDBSession from tracecat.dsl.models import RunActionInput from tracecat.logger import logger @@ -20,12 +20,12 @@ RegistryActionValidateResponse, ) from tracecat.registry.actions.service import RegistryActionsService -from tracecat.registry.constants import DEFAULT_REGISTRY_ORIGIN +from tracecat.registry.constants import DEFAULT_REGISTRY_ORIGIN, REGISTRY_ACTIONS_PATH from tracecat.types.auth import AccessLevel, Role from tracecat.types.exceptions import RegistryError from tracecat.validation.service import validate_registry_action_args -router = APIRouter(prefix="/registry/actions", tags=["registry-actions"]) +router = APIRouter(prefix=REGISTRY_ACTIONS_PATH, tags=["registry-actions"]) @router.get("") @@ -172,12 +172,13 @@ async def run_registry_action( ) -> Any: """Execute a registry action.""" ref = action_input.task.ref + ctx_role.set(role) act_logger = logger.bind(role=role, action_name=action_name, ref=ref) ctx_logger.set(act_logger) act_logger.info("Starting action") try: - return await executor.run_action_from_input(input=action_input) + return await executor.run_action_in_pool(input=action_input) except Exception as e: # Get the traceback info tb = traceback.extract_tb(e.__traceback__)[-1] # Get the last frame diff --git a/tracecat/registry/client.py b/tracecat/registry/client.py index 926e4919f..43bc50749 100644 --- a/tracecat/registry/client.py +++ b/tracecat/registry/client.py @@ -17,6 +17,7 @@ RegistryActionRead, RegistryActionValidateResponse, ) +from tracecat.registry.constants import REGISTRY_ACTIONS_PATH, REGISTRY_REPOS_PATH from tracecat.types.auth import Role from tracecat.types.exceptions import RegistryActionError, RegistryError @@ -25,7 +26,7 @@ class _RegistryHTTPClient(AuthenticatedServiceClient): """Async httpx client for the registry service.""" def __init__(self, role: Role | None = None, *args: Any, **kwargs: Any) -> None: - self._registry_base_url = config.TRACECAT__API_URL + self._registry_base_url = config.TRACECAT__REGISTRY_URL super().__init__(role, *args, base_url=self._registry_base_url, **kwargs) self.params = self.params.add("workspace_id", str(self.role.workspace_id)) @@ -33,13 +34,13 @@ def __init__(self, role: Role | None = None, *args: Any, **kwargs: Any) -> None: class RegistryClient: """Use this to interact with the remote registry service.""" - _repos_endpoint = "/registry/repos" - _actions_endpoint = "/registry/actions" + _repos_endpoint = REGISTRY_REPOS_PATH + _actions_endpoint = REGISTRY_ACTIONS_PATH _timeout: float = 60.0 def __init__(self, role: Role | None = None): self.role = role or ctx_role.get() - self.logger = logger.bind(service="remote-registry", role=self.role) + self.logger = logger.bind(service="registry-client", role=self.role) """Execution""" @@ -72,11 +73,11 @@ async def call_action(self, input: RunActionInput) -> Any: but are included in the method signature for potential future use. """ - key = input.task.action + action_type = input.task.action content = input.model_dump_json() workspace_id = str(self.role.workspace_id) if self.role.workspace_id else None logger.debug( - f"Calling action {key!r} with content", + f"Calling action {action_type!r} with content", content=content, role=self.role, timeout=self._timeout, @@ -84,7 +85,7 @@ async def call_action(self, input: RunActionInput) -> Any: try: async with _RegistryHTTPClient(self.role) as client: response = await client.post( - f"{self._actions_endpoint}/{key}/execute", + f"{self._actions_endpoint}/{action_type}/execute", # NOTE(perf): Maybe serialize with orjson.dumps instead headers={ "Content-Type": "application/json", @@ -116,7 +117,7 @@ async def call_action(self, input: RunActionInput) -> Any: logger.error("Registry returned an error", error=e, detail=detail) if e.response.status_code / 100 == 5: raise RegistryActionError( - f"There was an error in the registry when calling action {key!r} ({e.response.status_code}).\n\n{detail}" + f"There was an error in the registry when calling action {action_type!r} ({e.response.status_code}).\n\n{detail}" ) from e else: raise RegistryActionError( @@ -124,15 +125,15 @@ async def call_action(self, input: RunActionInput) -> Any: ) from e except httpx.ReadTimeout as e: raise RegistryActionError( - f"Timeout calling action {key!r} in registry: {e}" + f"Timeout calling action {action_type!r} in registry: {e}" ) from e except orjson.JSONDecodeError as e: raise RegistryActionError( - f"Error decoding JSON response for action {key!r}: {e}" + f"Error decoding JSON response for action {action_type!r}: {e}" ) from e except Exception as e: raise RegistryActionError( - f"Unexpected error calling action {key!r} in registry: {e}" + f"Unexpected error calling action {action_type!r} in registry: {e}" ) from e """Validation""" diff --git a/tracecat/registry/constants.py b/tracecat/registry/constants.py index 1ba85bd3f..25167d959 100644 --- a/tracecat/registry/constants.py +++ b/tracecat/registry/constants.py @@ -2,3 +2,9 @@ DEFAULT_REMOTE_REGISTRY_ORIGIN = "remote" CUSTOM_REPOSITORY_ORIGIN = "custom" GITHUB_SSH_KEY_SECRET_NAME = "github-ssh-key" + +REGISTRY_REPOS_PATH: str = "/repos" +"""Base path for repository-related endpoints""" + +REGISTRY_ACTIONS_PATH: str = "/actions" +"""Base path for action-related endpoints""" diff --git a/tracecat/registry/executor.py b/tracecat/registry/executor.py index 4cdb8ada0..ffd7de38a 100644 --- a/tracecat/registry/executor.py +++ b/tracecat/registry/executor.py @@ -7,12 +7,16 @@ import asyncio from collections.abc import Iterator, Mapping +from concurrent.futures import ProcessPoolExecutor from typing import Any, cast +import uvloop + from tracecat import config from tracecat.auth.sandbox import AuthSandbox from tracecat.concurrency import GatheringTaskGroup -from tracecat.contexts import ctx_logger, ctx_run +from tracecat.contexts import ctx_logger, ctx_role, ctx_run +from tracecat.db.engine import get_async_engine from tracecat.dsl.common import context_locator, create_default_dsl_context from tracecat.dsl.models import ( ActionStatement, @@ -21,11 +25,12 @@ RunActionInput, ) from tracecat.expressions.eval import ( + OperandType, eval_templated_object, extract_templated_secrets, get_iterables_from_expression, ) -from tracecat.expressions.shared import ExprContext, ExprContextType +from tracecat.expressions.shared import ExprContext from tracecat.logger import logger from tracecat.parse import traverse_leaves from tracecat.registry.actions.models import ArgsClsT, BoundRegistryAction @@ -33,12 +38,53 @@ from tracecat.secrets.common import apply_masks_object from tracecat.secrets.constants import DEFAULT_SECRETS_ENVIRONMENT from tracecat.secrets.secrets_manager import env_sandbox +from tracecat.types.auth import Role from tracecat.types.exceptions import TracecatException """All these methods are used in the registry executor, not on the worker""" type ArgsT = Mapping[str, Any] +_executor: ProcessPoolExecutor | None = None + +# We want to be able to serve a looped action +# Before we send out tasks to the executor we should inspect the size of the loop +# and set the right chunk size for each worker + + +def get_executor() -> ProcessPoolExecutor: + """Get the executor, creating it if it doesn't exist""" + global _executor + if _executor is None: + _executor = ProcessPoolExecutor() + return _executor + + +def sync_executor_entrypoint(input: RunActionInput[ArgsT], role: Role) -> Any: + """Run an action on the executor (API, not worker)""" + + logger.info("Running action in pool", input=input) + + async def coro(): + ctx_role.set(role) + async_engine = get_async_engine() + try: + return await run_action_from_input(input=input) + finally: + await async_engine.dispose() + + return uvloop.run(coro()) + + +async def run_action_in_pool(input: RunActionInput[ArgsT]) -> Any: + """Run an action on the executor (API, not worker)""" + loop = asyncio.get_running_loop() + role = ctx_role.get() + result = await loop.run_in_executor( + get_executor(), sync_executor_entrypoint, input, role + ) + return result + async def _run_action_direct( *, action: BoundRegistryAction[ArgsClsT], args: ArgsT, validate: bool = False @@ -68,69 +114,23 @@ async def _run_action_direct( async def run_single_action( *, - action_name: str, + action: BoundRegistryAction[ArgsClsT], args: ArgsT, - context: dict[str, Any] | None = None, + context: DSLContext | None = None, ) -> Any: """Run a UDF async.""" - # NOTE(perf): We might want to cache this, or call at a higher level - async with RegistryActionsService.with_session() as service: - action = await service.load_action_impl(action_name=action_name) - validated_args = action.validate_args(**args) - - logger.trace("Running regular UDF async", action=action_name) - secret_names = [secret.name for secret in action.secrets or []] - optional_secrets = [ - secret.name for secret in action.secrets or [] if secret.optional - ] - run_context = ctx_run.get() - environment = getattr(run_context, "environment", DEFAULT_SECRETS_ENVIRONMENT) - async with ( - AuthSandbox( - secrets=secret_names, - target="context", - environment=environment, - optional_secrets=optional_secrets, - ) as sandbox, - ): - # Flatten the secrets to a dict[str, str] - secret_context = sandbox.secrets.copy() - if action.is_template: - logger.info("Running template UDF async", action=action_name) - context_with_secrets = context.copy() if context else {} - # Merge the secrets from the sandbox with the existing context - context_with_secrets[ExprContext.SECRETS] = ( - context_with_secrets.get(ExprContext.SECRETS, {}) | secret_context - ) - return await run_template_action( - action=action, - args=validated_args, - context=context_with_secrets, - ) - # Given secrets in the format of {name: {key: value}}, we need to flatten - # it to a dict[str, str] to set in the environment context - flattened_secrets: dict[str, str] = {} - for name, keyvalues in secret_context.items(): - for key, value in keyvalues.items(): - if key in flattened_secrets: - raise ValueError( - f"Key {key!r} is duplicated in {name!r}! " - "Please ensure only one secret with a given name is set. " - "e.g. If you have `first_secret.KEY` set, then you cannot " - "also set `second_secret.KEY` as `KEY` is duplicated." - ) - flattened_secrets[key] = value - - with env_sandbox(flattened_secrets): - # Run the UDF in the caller process (usually the worker) - return await _run_action_direct(action=action, args=validated_args) + if action.is_template: + logger.info("Running template UDF async", action=action.name) + return await run_template_action(action=action, args=args, context=context) + # Run the UDF in the caller process (usually the worker) + return await _run_action_direct(action=action, args=args) async def run_template_action( *, action: BoundRegistryAction[ArgsClsT], args: ArgsT, - context: DSLContext, + context: DSLContext | None = None, ) -> Any: """Handle template execution. @@ -146,9 +146,11 @@ async def run_template_action( ) defn = action.template_action.definition template_context = cast( - ExprContextType, - context.copy() - | { + DSLContext, + { + ExprContext.SECRETS: {} + if context is None + else context.get(ExprContext.SECRETS, {}), ExprContext.TEMPLATE_ACTION_INPUTS: args, ExprContext.TEMPLATE_ACTION_STEPS: {}, }, @@ -157,10 +159,15 @@ async def run_template_action( for step in defn.steps: evaled_args = cast( - ArgsT, eval_templated_object(step.args, operand=template_context) + ArgsT, + eval_templated_object( + step.args, operand=cast(OperandType, template_context) + ), ) + async with RegistryActionsService.with_session() as service: + step_action = await service.load_action_impl(action_name=step.action) result = await run_single_action( - action_name=step.action, + action=step_action, args=evaled_args, context=template_context, ) @@ -172,13 +179,15 @@ async def run_template_action( ) # Handle returns - return eval_templated_object(defn.returns, operand=template_context) + return eval_templated_object( + defn.returns, operand=cast(OperandType, template_context) + ) async def run_action_from_input(input: RunActionInput) -> Any: """This runs on the executor (API, not worker)""" ctx_run.set(input.run_context) - act_logger = ctx_logger.get() + act_logger = ctx_logger.get(logger.bind(ref=input.task.ref)) task = input.task environment = input.run_context.environment @@ -200,18 +209,26 @@ async def run_action_from_input(input: RunActionInput) -> Any: # 2. Load the secrets # 3. Inject the secrets into the task arguments using an enriched context # NOTE: Regardless of loop iteration, we should only make this call/substitution once!! - secret_refs = extract_templated_secrets(task.args) + + async with RegistryActionsService.with_session() as service: + action = await service.load_action_impl(action_name=action_name) + + run_context = ctx_run.get() + environment = getattr(run_context, "environment", DEFAULT_SECRETS_ENVIRONMENT) + + action_secret_names = {secret.name for secret in action.secrets or []} + optional_secrets = { + secret.name for secret in action.secrets or [] if secret.optional + } + args_secret_refs = set(extract_templated_secrets(task.args)) async with AuthSandbox( - secrets=secret_refs, target="context", environment=environment + secrets=list(action_secret_names | args_secret_refs), + target="context", + environment=environment, + optional_secrets=list(optional_secrets), ) as sandbox: secrets = sandbox.secrets.copy() - context_with_secrets = DSLContext( - **{ - **input.exec_context, - ExprContext.SECRETS: secrets, - } - ) if config.TRACECAT__UNSAFE_DISABLE_SM_MASKING: act_logger.warning( @@ -231,40 +248,55 @@ async def run_action_from_input(input: RunActionInput) -> Any: args=task.args, ) - # Actual execution - if task.for_each: - iterator = iter_for_each(task=task, context=context_with_secrets) - try: - async with GatheringTaskGroup() as tg: - for patched_args in iterator: - tg.create_task( - run_single_action( - action_name=action_name, - args=patched_args, - context=context_with_secrets, + context = input.exec_context.copy() + context.update(SECRETS=secrets) + + # Given secrets in the format of {name: {key: value}}, we need to flatten + # it to a dict[str, str] to set in the environment context + flattened_secrets: dict[str, str] = {} + for name, keyvalues in secrets.items(): + for key, value in keyvalues.items(): + if key in flattened_secrets: + raise ValueError( + f"Key {key!r} is duplicated in {name!r}! " + "Please ensure only one secret with a given name is set. " + "e.g. If you have `first_secret.KEY` set, then you cannot " + "also set `second_secret.KEY` as `KEY` is duplicated." + ) + flattened_secrets[key] = value + + with env_sandbox(flattened_secrets): + # Actual execution + if task.for_each: + # If the action is CPU bound, just run it directly + # Otherwise, we want to parallelize it + iterator = iter_for_each(task=task, context=context) + + try: + async with GatheringTaskGroup() as tg: + for patched_args in iterator: + tg.create_task( + run_single_action( + action=action, args=patched_args, context=context + ) ) - ) - - result = tg.results() - except* Exception as eg: - errors = [str(x) for x in eg.exceptions] - logger.error("Error resolving expressions", errors=errors) - raise TracecatException( - ( - f"[{context_locator(task, 'for_each')}]" - "\n\nError in loop:" - f"\n\n{'\n\n'.join(errors)}" - ), - detail={"errors": errors}, - ) from eg - else: - args = evaluate_templated_args(task, context_with_secrets) - result = await run_single_action( - action_name=action_name, - args=args, - context=cast(dict[str, Any], context_with_secrets), - ) + result = tg.results() + except* Exception as eg: + errors = [str(x) for x in eg.exceptions] + logger.error("Error resolving expressions", errors=errors) + raise TracecatException( + ( + f"[{context_locator(task, 'for_each')}]" + "\n\nError in loop:" + f"\n\n{'\n\n'.join(errors)}" + ), + detail={"errors": errors}, + ) from eg + + else: + args = evaluate_templated_args(task, context) + result = await run_single_action(action=action, args=args, context=context) if mask_values: result = apply_masks_object(result, masks=mask_values) @@ -300,28 +332,20 @@ def iter_for_each( raise ValueError("No loop expression found") iterators = get_iterables_from_expression(expr=task.for_each, operand=context) - # Assert that all length of the iterables are the same - # This is a requirement for parallel processing - # if len({len(expr.collection) for expr in iterators}) != 1: - # raise ValueError("All iterables must be of the same length") + # Patch the context with the loop item and evaluate the action-local expressions + # We're copying this so that we don't pollute the original context + # Currently, the only source of action-local expressions is the loop iteration + # In the future, we may have other sources of action-local expressions + # XXX: ENV is the only context that should be shared + patched_context = context.copy() if patch else create_default_dsl_context() + logger.trace("Context before patch", patched_context=patched_context) # Create a generator that zips the iterables together for i, items in enumerate(zip(*iterators, strict=False)): logger.trace("Loop iteration", iteration=i) - # Patch the context with the loop item and evaluate the action-local expressions - # We're copying this so that we don't pollute the original context - # Currently, the only source of action-local expressions is the loop iteration - # In the future, we may have other sources of action-local expressions - patched_context = ( - context.copy() - if patch - # XXX: ENV is the only context that should be shared - else create_default_dsl_context() - ) - logger.trace("Context before patch", patched_context=patched_context) for iterator_path, iterator_value in items: patch_object( - cast(dict[str, Any], patched_context), + obj=patched_context, # type: ignore path=assign_context + iterator_path, value=iterator_value, ) diff --git a/tracecat/registry/repositories/router.py b/tracecat/registry/repositories/router.py index 88c2fe2a7..de5b419fb 100644 --- a/tracecat/registry/repositories/router.py +++ b/tracecat/registry/repositories/router.py @@ -10,6 +10,7 @@ from tracecat.registry.constants import ( CUSTOM_REPOSITORY_ORIGIN, DEFAULT_REGISTRY_ORIGIN, + REGISTRY_REPOS_PATH, ) from tracecat.registry.repositories.models import ( RegistryRepositoryCreate, @@ -22,7 +23,7 @@ from tracecat.types.auth import AccessLevel, Role from tracecat.types.exceptions import RegistryError, TracecatNotFoundError -router = APIRouter(prefix="/registry/repos", tags=["registry-repositories"]) +router = APIRouter(prefix=REGISTRY_REPOS_PATH, tags=["registry-repositories"]) # Controls From 04f9a027d26a8202730d0a91ae8e923cb4f76d6c Mon Sep 17 00:00:00 2001 From: topher-lo <46541035+topher-lo@users.noreply.github.com> Date: Thu, 5 Dec 2024 17:13:47 -0800 Subject: [PATCH 02/13] ci(infra): Add registry service to fargate --- deployments/aws/ecs/ecs-registry.tf | 91 +++++++++++++++++++++++++++++ deployments/aws/ecs/iam.tf | 6 ++ deployments/aws/ecs/locals.tf | 10 ++++ deployments/aws/ecs/variables.tf | 10 ++++ deployments/aws/main.tf | 2 + deployments/aws/variables.tf | 10 ++++ docker-compose.yml | 3 +- 7 files changed, 130 insertions(+), 2 deletions(-) create mode 100644 deployments/aws/ecs/ecs-registry.tf diff --git a/deployments/aws/ecs/ecs-registry.tf b/deployments/aws/ecs/ecs-registry.tf new file mode 100644 index 000000000..6f46912f1 --- /dev/null +++ b/deployments/aws/ecs/ecs-registry.tf @@ -0,0 +1,91 @@ +# ECS Task Definition for Registry Service +resource "aws_ecs_task_definition" "registry_task_definition" { + family = "TracecatRegistryTaskDefinition" + network_mode = "awsvpc" + requires_compatibilities = ["FARGATE"] + cpu = var.registry_cpu + memory = var.registry_memory + execution_role_arn = aws_iam_role.worker_execution.arn + task_role_arn = aws_iam_role.api_worker_task.arn + + container_definitions = jsonencode([ + { + name = "TracecatRegistryContainer" + image = "${var.tracecat_image}:${local.tracecat_image_tag}" + command = [ + "python", + "-m", + "uvicorn", + "tracecat.api.registry:app", + "--host", + "0.0.0.0", + "--port", + "8000" + ] + portMappings = [ + { + containerPort = 8000 + hostPort = 8000 + name = "registry" + appProtocol = "http" + } + ] + logConfiguration = { + logDriver = "awslogs" + options = { + awslogs-group = aws_cloudwatch_log_group.tracecat_log_group.name + awslogs-region = var.aws_region + awslogs-stream-prefix = "registry" + } + } + environment = local.registry_env + secrets = local.tracecat_secrets + dockerPullConfig = { + maxAttempts = 3 + backoffTime = 10 + } + } + ]) +} + +resource "aws_ecs_service" "tracecat_registry" { + name = "tracecat-registry" + cluster = aws_ecs_cluster.tracecat_cluster.id + task_definition = aws_ecs_task_definition.registry_task_definition.arn + launch_type = "FARGATE" + desired_count = 1 + force_new_deployment = var.force_new_deployment + + network_configuration { + subnets = var.private_subnet_ids + security_groups = [ + aws_security_group.core.id, + aws_security_group.core_db.id, + ] + } + + service_connect_configuration { + enabled = true + namespace = local.local_dns_namespace + service { + port_name = "registry" + discovery_name = "registry-service" + timeout { + per_request_timeout_seconds = 120 + } + client_alias { + port = 8000 + dns_name = "registry-service" + } + } + + log_configuration { + log_driver = "awslogs" + options = { + awslogs-group = aws_cloudwatch_log_group.tracecat_log_group.name + awslogs-region = var.aws_region + awslogs-stream-prefix = "service-connect-registry" + } + } + } +} diff --git a/deployments/aws/ecs/iam.tf b/deployments/aws/ecs/iam.tf index 813c47d7b..9302d96ab 100644 --- a/deployments/aws/ecs/iam.tf +++ b/deployments/aws/ecs/iam.tf @@ -110,6 +110,12 @@ resource "aws_iam_role_policy_attachment" "worker_execution_secrets" { role = aws_iam_role.worker_execution.name } +# Registry execution role +resource "aws_iam_role" "registry_execution" { + name = "TracecatRegistryExecutionRole" + assume_role_policy = data.aws_iam_policy_document.assume_role.json +} + # UI execution role resource "aws_iam_role" "ui_execution" { name = "TracecatUIExecutionRole" diff --git a/deployments/aws/ecs/locals.tf b/deployments/aws/ecs/locals.tf index 10e076a6e..25164252d 100644 --- a/deployments/aws/ecs/locals.tf +++ b/deployments/aws/ecs/locals.tf @@ -58,6 +58,16 @@ locals { { name = k, value = tostring(v) } ] + registry_env = [ + for k, v in merge({ + LOG_LEVEL = var.log_level + TRACECAT__APP_ENV = var.tracecat_app_env + TRACECAT__REMOTE_REPOSITORY_URL = var.remote_repository_url + TRACECAT__REMOTE_REPOSITORY_PACKAGE_NAME = var.remote_repository_package_name + }, local.tracecat_db_configs) : + { name = k, value = tostring(v) } + ] + ui_env = [ for k, v in { NEXT_PUBLIC_API_URL = local.public_api_url diff --git a/deployments/aws/ecs/variables.tf b/deployments/aws/ecs/variables.tf index a51f765f3..e99c5273c 100644 --- a/deployments/aws/ecs/variables.tf +++ b/deployments/aws/ecs/variables.tf @@ -212,6 +212,16 @@ variable "worker_memory" { default = "512" } +variable "registry_cpu" { + type = string + default = "256" +} + +variable "registry_memory" { + type = string + default = "512" +} + variable "ui_cpu" { type = string default = "256" diff --git a/deployments/aws/main.tf b/deployments/aws/main.tf index c918f194a..bef4b235b 100644 --- a/deployments/aws/main.tf +++ b/deployments/aws/main.tf @@ -78,6 +78,8 @@ module "ecs" { api_memory = var.api_memory worker_cpu = var.worker_cpu worker_memory = var.worker_memory + registry_cpu = var.registry_cpu + registry_memory = var.registry_memory ui_cpu = var.ui_cpu ui_memory = var.ui_memory temporal_cpu = var.temporal_cpu diff --git a/deployments/aws/variables.tf b/deployments/aws/variables.tf index e69d938fd..827358b7e 100644 --- a/deployments/aws/variables.tf +++ b/deployments/aws/variables.tf @@ -145,6 +145,16 @@ variable "worker_memory" { default = "512" } +variable "registry_cpu" { + type = string + default = "256" +} + +variable "registry_memory" { + type = string + default = "512" +} + variable "ui_cpu" { type = string default = "256" diff --git a/docker-compose.yml b/docker-compose.yml index 5e5a2c065..b401e1fd5 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -79,7 +79,6 @@ services: # Temporal TEMPORAL__CLUSTER_URL: ${TEMPORAL__CLUSTER_URL} TEMPORAL__CLUSTER_QUEUE: ${TEMPORAL__CLUSTER_QUEUE} - command: ["python", "tracecat/dsl/worker.py"] registry: @@ -99,7 +98,7 @@ services: # Registry TRACECAT__REMOTE_REPOSITORY_URL: ${TRACECAT__REMOTE_REPOSITORY_URL} TRACECAT__REMOTE_REPOSITORY_PACKAGE_NAME: ${TRACECAT__REMOTE_REPOSITORY_PACKAGE_NAME} - entrypoint: + command: [ "python", "-m", From 85fd66f8b8835e97913fecaf6a91a35afc3c3061 Mon Sep 17 00:00:00 2001 From: topher-lo <46541035+topher-lo@users.noreply.github.com> Date: Thu, 5 Dec 2024 17:19:55 -0800 Subject: [PATCH 03/13] fix(ui): Redirect SAML --- frontend/src/app/auth/oauth/callback/route.ts | 2 +- frontend/src/app/auth/saml/acs/route.ts | 15 ++++++---- frontend/src/lib/ss-utils.ts | 29 ------------------- 3 files changed, 11 insertions(+), 35 deletions(-) diff --git a/frontend/src/app/auth/oauth/callback/route.ts b/frontend/src/app/auth/oauth/callback/route.ts index 450e3a33e..ba2edf436 100644 --- a/frontend/src/app/auth/oauth/callback/route.ts +++ b/frontend/src/app/auth/oauth/callback/route.ts @@ -14,7 +14,7 @@ export const GET = async (request: NextRequest) => { const response = await fetch(url.toString()) const setCookieHeader = response.headers.get("set-cookie") - // Get redirect + // Get redirect const resp = await fetch(buildUrl("/info")) const { public_app_url } = await resp.json() console.log("Public app URL", public_app_url) diff --git a/frontend/src/app/auth/saml/acs/route.ts b/frontend/src/app/auth/saml/acs/route.ts index a6aaa5918..b4522505e 100644 --- a/frontend/src/app/auth/saml/acs/route.ts +++ b/frontend/src/app/auth/saml/acs/route.ts @@ -1,6 +1,6 @@ import { NextRequest, NextResponse } from "next/server" -import { buildUrl, getDomain } from "@/lib/ss-utils" +import { buildUrl } from "@/lib/ss-utils" /** * @param request @@ -13,9 +13,14 @@ export async function POST(request: NextRequest) { const formData = await request.formData() const samlResponse = formData.get('SAMLResponse') + // Get redirect + const resp = await fetch(buildUrl("/info")) + const { public_app_url } = await resp.json() + console.log("Public app URL", public_app_url) + if (!samlResponse) { console.error("No SAML response found in the request") - return NextResponse.redirect(new URL("/auth/error", getDomain(request))) + return NextResponse.redirect(new URL("/auth/error", public_app_url)) } // Prepare the request to the FastAPI backend @@ -31,18 +36,18 @@ export async function POST(request: NextRequest) { if (!backendResponse.ok) { console.error("Error from backend:", await backendResponse.text()) - return NextResponse.redirect(new URL("/auth/error", getDomain(request))) + return NextResponse.redirect(new URL("/auth/error", public_app_url)) } const setCookieHeader = backendResponse.headers.get("set-cookie") if (!setCookieHeader) { console.error("No set-cookie header found in response") - return NextResponse.redirect(new URL("/auth/error", getDomain(request))) + return NextResponse.redirect(new URL("/auth/error", public_app_url)) } console.log("Redirecting to / with GET") - const redirectUrl = new URL("/", getDomain(request)) + const redirectUrl = new URL("/", public_app_url) const redirectResponse = NextResponse.redirect(redirectUrl, { status: 303 // Force GET request }) diff --git a/frontend/src/lib/ss-utils.ts b/frontend/src/lib/ss-utils.ts index 742119faa..575bbf656 100644 --- a/frontend/src/lib/ss-utils.ts +++ b/frontend/src/lib/ss-utils.ts @@ -1,32 +1,3 @@ -import { NextRequest } from "next/server" - -export const getDomain = (request: NextRequest) => { - // use env variable if set - if (process.env.NEXT_PUBLIC_APP_URL) { - console.log( - "Redirecting to NEXT_PUBLIC_APP_URL:", - process.env.NEXT_PUBLIC_APP_URL - ) - return process.env.NEXT_PUBLIC_APP_URL - } - - // next, try and build domain from headers - const requestedHost = request.headers.get("X-Forwarded-Host") - const requestedPort = request.headers.get("X-Forwarded-Port") - const requestedProto = request.headers.get("X-Forwarded-Proto") - if (requestedHost) { - const url = request.nextUrl.clone() - url.host = requestedHost - url.protocol = requestedProto || url.protocol - url.port = requestedPort || url.port - console.log("Redirecting to requestedHost", url.origin) - return url.origin - } - - // finally just use whatever is in the request - return request.nextUrl.origin -} - export function buildUrl(path: string) { const url = process.env.NEXT_SERVER_API_URL || "http://api:8000" if (path.startsWith("/")) { From 22a0731c4b64680382f264edff973c8ddc558861 Mon Sep 17 00:00:00 2001 From: topher-lo <46541035+topher-lo@users.noreply.github.com> Date: Thu, 5 Dec 2024 17:37:55 -0800 Subject: [PATCH 04/13] docs(fix): Key for SAML SSO --- docs/self-hosting/authentication/introduction.mdx | 6 +++--- docs/self-hosting/authentication/okta.mdx | 2 +- docs/self-hosting/deployment-options/aws-ecs.mdx | 2 +- 3 files changed, 5 insertions(+), 5 deletions(-) diff --git a/docs/self-hosting/authentication/introduction.mdx b/docs/self-hosting/authentication/introduction.mdx index 8300d788c..40e684e1e 100644 --- a/docs/self-hosting/authentication/introduction.mdx +++ b/docs/self-hosting/authentication/introduction.mdx @@ -28,7 +28,7 @@ Tracecat currently supports the following authentication methods: - `basic`: Email and Password - `google_oauth`: Google OAuth -- `sso`: SAML SSO +- `saml`: SAML SSO Choose from a number of authentication methods listed below to get started. @@ -59,8 +59,8 @@ Choose from a number of authentication methods listed below to get started. ## Enable / Disable Authentication Methods You can enable / disable multiple authentication methods in the `.env` file by modifying the `TRACECAT__AUTH_TYPES` environment variable. -`TRACECAT__AUTH_TYPES` is a comma separated list of auth method keys: i.e. `basic`, `google_oauth`, `sso`. +`TRACECAT__AUTH_TYPES` is a comma separated list of auth method keys: i.e. `basic`, `google_oauth`, `saml`. ```bash -TRACECAT__AUTH_TYPES=basic,google_oauth,sso +TRACECAT__AUTH_TYPES=basic,google_oauth,saml ``` diff --git a/docs/self-hosting/authentication/okta.mdx b/docs/self-hosting/authentication/okta.mdx index 1ca5ffc60..53034d606 100644 --- a/docs/self-hosting/authentication/okta.mdx +++ b/docs/self-hosting/authentication/okta.mdx @@ -12,7 +12,7 @@ description: Learn how to authenticate into Tracecat with Okta SAML SSO. In your `.env` file, make sure you have the following value set. ```bash -TRACECAT__AUTH_TYPES=sso +TRACECAT__AUTH_TYPES=saml ``` ## Prerequisites diff --git a/docs/self-hosting/deployment-options/aws-ecs.mdx b/docs/self-hosting/deployment-options/aws-ecs.mdx index 7489df921..2873f4646 100644 --- a/docs/self-hosting/deployment-options/aws-ecs.mdx +++ b/docs/self-hosting/deployment-options/aws-ecs.mdx @@ -4,7 +4,7 @@ description: Use Terraform to deploy Tracecat into ECS Fargate. --- - This stack is meant for production use. `TRACECAT__AUTH_TYPES=google_oauth,sso` is the default configuration. + This stack is meant for production use. `TRACECAT__AUTH_TYPES=google_oauth,saml` is the default configuration. You'll need to configure [Google OAuth](/self-hosting/authentication/google) or [SAML SSO](/self-hosting/authentication/sso) to login. From d8863410bdbb595d4e55554ff2364a8d0065443c Mon Sep 17 00:00:00 2001 From: topher-lo <46541035+topher-lo@users.noreply.github.com> Date: Thu, 5 Dec 2024 18:03:39 -0800 Subject: [PATCH 05/13] fix: Fargate authtypes --- deployments/aws/ecs/variables.tf | 2 +- deployments/aws/variables.tf | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/deployments/aws/ecs/variables.tf b/deployments/aws/ecs/variables.tf index e99c5273c..8527b4908 100644 --- a/deployments/aws/ecs/variables.tf +++ b/deployments/aws/ecs/variables.tf @@ -68,7 +68,7 @@ variable "acm_certificate_arn" { variable "auth_types" { type = string - default = "google_oauth,sso" + default = "google_oauth,saml" } diff --git a/deployments/aws/variables.tf b/deployments/aws/variables.tf index 827358b7e..b35036bbf 100644 --- a/deployments/aws/variables.tf +++ b/deployments/aws/variables.tf @@ -21,7 +21,7 @@ variable "hosted_zone_id" { variable "auth_types" { type = string - default = "google_oauth,sso" + default = "google_oauth,saml" } From f9bda85ffd69a3b33dfaad4670e5b4d2347b966e Mon Sep 17 00:00:00 2001 From: topher-lo <46541035+topher-lo@users.noreply.github.com> Date: Thu, 5 Dec 2024 18:49:54 -0800 Subject: [PATCH 06/13] fix(infra): MIssing env vars --- deployments/aws/ecs/ecs-api.tf | 17 ++----------- deployments/aws/ecs/ecs-registry.tf | 18 ++++++++++---- deployments/aws/ecs/ecs-worker.tf | 9 ++----- deployments/aws/ecs/locals.tf | 33 +++++++++++++++----------- deployments/aws/ecs/security_groups.tf | 8 +++++++ docker-compose.dev.yml | 3 +++ docker-compose.yml | 7 ++++-- 7 files changed, 53 insertions(+), 42 deletions(-) diff --git a/deployments/aws/ecs/ecs-api.tf b/deployments/aws/ecs/ecs-api.tf index 3c11a6d0a..7bcb0ba19 100644 --- a/deployments/aws/ecs/ecs-api.tf +++ b/deployments/aws/ecs/ecs-api.tf @@ -28,21 +28,8 @@ resource "aws_ecs_task_definition" "api_task_definition" { awslogs-stream-prefix = "api" } } - environment = concat(local.api_env, [ - { - name = "TRACECAT__DB_ENDPOINT" - value = local.core_db_hostname - }, - { - name = "TRACECAT__REMOTE_REPOSITORY_PACKAGE_NAME" - value = var.remote_repository_package_name - }, - { - name = "TRACECAT__REMOTE_REPOSITORY_URL" - value = var.remote_repository_url - } - ]) - secrets = local.tracecat_secrets + environment = local.api_env + secrets = local.tracecat_secrets dockerPullConfig = { maxAttempts = 3 backoffTime = 10 diff --git a/deployments/aws/ecs/ecs-registry.tf b/deployments/aws/ecs/ecs-registry.tf index 6f46912f1..fd5fc8bbd 100644 --- a/deployments/aws/ecs/ecs-registry.tf +++ b/deployments/aws/ecs/ecs-registry.tf @@ -20,12 +20,12 @@ resource "aws_ecs_task_definition" "registry_task_definition" { "--host", "0.0.0.0", "--port", - "8000" + "8002" ] portMappings = [ { - containerPort = 8000 - hostPort = 8000 + containerPort = 8002 + hostPort = 8002 name = "registry" appProtocol = "http" } @@ -46,6 +46,11 @@ resource "aws_ecs_task_definition" "registry_task_definition" { } } ]) + + depends_on = [ + aws_ecs_service.temporal_service, + aws_ecs_task_definition.temporal_task_definition, + ] } resource "aws_ecs_service" "tracecat_registry" { @@ -74,7 +79,7 @@ resource "aws_ecs_service" "tracecat_registry" { per_request_timeout_seconds = 120 } client_alias { - port = 8000 + port = 8002 dns_name = "registry-service" } } @@ -88,4 +93,9 @@ resource "aws_ecs_service" "tracecat_registry" { } } } + + depends_on = [ + aws_ecs_service.temporal_service, + aws_ecs_task_definition.temporal_task_definition, + ] } diff --git a/deployments/aws/ecs/ecs-worker.tf b/deployments/aws/ecs/ecs-worker.tf index 697599119..d74ff0a43 100644 --- a/deployments/aws/ecs/ecs-worker.tf +++ b/deployments/aws/ecs/ecs-worker.tf @@ -29,13 +29,8 @@ resource "aws_ecs_task_definition" "worker_task_definition" { awslogs-stream-prefix = "worker" } } - environment = concat(local.worker_env, [ - { - name = "TRACECAT__DB_ENDPOINT" - value = local.core_db_hostname - } - ]) - secrets = local.tracecat_secrets + environment = local.worker_env + secrets = local.tracecat_secrets dockerPullConfig = { maxAttempts = 3 backoffTime = 30 diff --git a/deployments/aws/ecs/locals.tf b/deployments/aws/ecs/locals.tf index 25164252d..c52ad10c1 100644 --- a/deployments/aws/ecs/locals.tf +++ b/deployments/aws/ecs/locals.tf @@ -26,20 +26,23 @@ locals { api_env = [ for k, v in merge({ - LOG_LEVEL = var.log_level - TRACECAT__API_URL = local.internal_api_url - TRACECAT__API_ROOT_PATH = "/api" - TRACECAT__APP_ENV = var.tracecat_app_env - TRACECAT__PUBLIC_RUNNER_URL = local.public_api_url - TRACECAT__PUBLIC_APP_URL = local.public_app_url - TRACECAT__ALLOW_ORIGINS = local.allow_origins - TRACECAT__AUTH_TYPES = var.auth_types - TRACECAT__AUTH_ALLOWED_DOMAINS = var.auth_allowed_domains - TEMPORAL__CLUSTER_URL = local.temporal_cluster_url - TEMPORAL__CLUSTER_QUEUE = local.temporal_cluster_queue - TEMPORAL__CLIENT_RPC_TIMEOUT = var.temporal_client_rpc_timeout - SAML_SP_ACS_URL = local.saml_acs_url - RUN_MIGRATIONS = "true" + LOG_LEVEL = var.log_level + TRACECAT__API_URL = local.internal_api_url + TRACECAT__API_ROOT_PATH = "/api" + TRACECAT__APP_ENV = var.tracecat_app_env + TRACECAT__DB_ENDPOINT = local.core_db_hostname + TRACECAT__PUBLIC_RUNNER_URL = local.public_api_url + TRACECAT__PUBLIC_APP_URL = local.public_app_url + TRACECAT__ALLOW_ORIGINS = local.allow_origins + TRACECAT__AUTH_TYPES = var.auth_types + TRACECAT__AUTH_ALLOWED_DOMAINS = var.auth_allowed_domains + TEMPORAL__CLUSTER_URL = local.temporal_cluster_url + TEMPORAL__CLUSTER_QUEUE = local.temporal_cluster_queue + TEMPORAL__CLIENT_RPC_TIMEOUT = var.temporal_client_rpc_timeout + SAML_SP_ACS_URL = local.saml_acs_url + TRACECAT__REMOTE_REPOSITORY_PACKAGE_NAME = var.remote_repository_package_name + TRACECAT__REMOTE_REPOSITORY_URL = var.remote_repository_url + RUN_MIGRATIONS = "true" }, local.tracecat_db_configs) : { name = k, value = tostring(v) } ] @@ -50,6 +53,7 @@ locals { TRACECAT__API_URL = local.internal_api_url TRACECAT__API_ROOT_PATH = "/api" TRACECAT__APP_ENV = var.tracecat_app_env + TRACECAT__DB_ENDPOINT = local.core_db_hostname TRACECAT__PUBLIC_RUNNER_URL = local.public_api_url TEMPORAL__CLUSTER_URL = local.temporal_cluster_url TEMPORAL__CLUSTER_QUEUE = local.temporal_cluster_queue @@ -62,6 +66,7 @@ locals { for k, v in merge({ LOG_LEVEL = var.log_level TRACECAT__APP_ENV = var.tracecat_app_env + TRACECAT__DB_ENDPOINT = local.core_db_hostname TRACECAT__REMOTE_REPOSITORY_URL = var.remote_repository_url TRACECAT__REMOTE_REPOSITORY_PACKAGE_NAME = var.remote_repository_package_name }, local.tracecat_db_configs) : diff --git a/deployments/aws/ecs/security_groups.tf b/deployments/aws/ecs/security_groups.tf index 30bbdc22c..392578ab3 100644 --- a/deployments/aws/ecs/security_groups.tf +++ b/deployments/aws/ecs/security_groups.tf @@ -66,6 +66,14 @@ resource "aws_security_group" "core" { self = true } + ingress { + description = "Allow internal traffic to the Tracecat Registry service on port 8002" + from_port = 8002 + to_port = 8002 + protocol = "tcp" + self = true + } + ingress { description = "Allow internal traffic to the Tracecat UI service on port 3000" from_port = 3000 diff --git a/docker-compose.dev.yml b/docker-compose.dev.yml index d374102a6..cffea2e9b 100644 --- a/docker-compose.dev.yml +++ b/docker-compose.dev.yml @@ -107,6 +107,9 @@ services: TRACECAT__REMOTE_REPOSITORY_URL: ${TRACECAT__REMOTE_REPOSITORY_URL} TRACECAT__REMOTE_REPOSITORY_PACKAGE_NAME: ${TRACECAT__REMOTE_REPOSITORY_PACKAGE_NAME} TRACECAT__UNSAFE_DISABLE_SM_MASKING: ${TRACECAT__UNSAFE_DISABLE_SM_MASKING:-false} + # AI + TRACECAT__PRELOAD_OSS_MODELS: ${TRACECAT__PRELOAD_OSS_MODELS} + OLLAMA__API_URL: ${OLLAMA__API_URL} volumes: - ./tracecat:/app/tracecat - ./registry:/app/registry diff --git a/docker-compose.yml b/docker-compose.yml index b401e1fd5..95814a9fe 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -95,9 +95,12 @@ services: TRACECAT__DB_SSLMODE: ${TRACECAT__DB_SSLMODE} TRACECAT__DB_URI: ${TRACECAT__DB_URI} # Sensitive TRACECAT__SERVICE_KEY: ${TRACECAT__SERVICE_KEY} # Sensitive - # Registry - TRACECAT__REMOTE_REPOSITORY_URL: ${TRACECAT__REMOTE_REPOSITORY_URL} + # Remote registry TRACECAT__REMOTE_REPOSITORY_PACKAGE_NAME: ${TRACECAT__REMOTE_REPOSITORY_PACKAGE_NAME} + TRACECAT__REMOTE_REPOSITORY_URL: ${TRACECAT__REMOTE_REPOSITORY_URL} + # AI + TRACECAT__PRELOAD_OSS_MODELS: ${TRACECAT__PRELOAD_OSS_MODELS} + OLLAMA__API_URL: ${OLLAMA__API_URL} command: [ "python", From 9ee0154985e0500dcec3c9eac7d5a8d3209cdd7d Mon Sep 17 00:00:00 2001 From: topher-lo <46541035+topher-lo@users.noreply.github.com> Date: Thu, 5 Dec 2024 19:07:13 -0800 Subject: [PATCH 07/13] fix(infra): saml secrets arn policy --- deployments/aws/ecs/ecs-api.tf | 2 +- deployments/aws/ecs/ecs-registry.tf | 2 +- deployments/aws/ecs/ecs-worker.tf | 2 +- deployments/aws/ecs/iam.tf | 6 +++++- deployments/aws/ecs/secrets.tf | 6 +++--- 5 files changed, 11 insertions(+), 7 deletions(-) diff --git a/deployments/aws/ecs/ecs-api.tf b/deployments/aws/ecs/ecs-api.tf index 7bcb0ba19..71c3cd8b4 100644 --- a/deployments/aws/ecs/ecs-api.tf +++ b/deployments/aws/ecs/ecs-api.tf @@ -29,7 +29,7 @@ resource "aws_ecs_task_definition" "api_task_definition" { } } environment = local.api_env - secrets = local.tracecat_secrets + secrets = local.tracecat_api_secrets dockerPullConfig = { maxAttempts = 3 backoffTime = 10 diff --git a/deployments/aws/ecs/ecs-registry.tf b/deployments/aws/ecs/ecs-registry.tf index fd5fc8bbd..978dd02bb 100644 --- a/deployments/aws/ecs/ecs-registry.tf +++ b/deployments/aws/ecs/ecs-registry.tf @@ -39,7 +39,7 @@ resource "aws_ecs_task_definition" "registry_task_definition" { } } environment = local.registry_env - secrets = local.tracecat_secrets + secrets = local.tracecat_base_secrets dockerPullConfig = { maxAttempts = 3 backoffTime = 10 diff --git a/deployments/aws/ecs/ecs-worker.tf b/deployments/aws/ecs/ecs-worker.tf index d74ff0a43..202034af0 100644 --- a/deployments/aws/ecs/ecs-worker.tf +++ b/deployments/aws/ecs/ecs-worker.tf @@ -30,7 +30,7 @@ resource "aws_ecs_task_definition" "worker_task_definition" { } } environment = local.worker_env - secrets = local.tracecat_secrets + secrets = local.tracecat_base_secrets dockerPullConfig = { maxAttempts = 3 backoffTime = 30 diff --git a/deployments/aws/ecs/iam.tf b/deployments/aws/ecs/iam.tf index 9302d96ab..cf21cd95c 100644 --- a/deployments/aws/ecs/iam.tf +++ b/deployments/aws/ecs/iam.tf @@ -49,7 +49,11 @@ resource "aws_iam_policy" "secrets_access" { var.tracecat_service_key_arn, var.tracecat_signing_secret_arn, var.oauth_client_id_arn, - var.oauth_client_secret_arn + var.oauth_client_secret_arn, + var.saml_idp_entity_id_arn, + var.saml_idp_redirect_url_arn, + var.saml_idp_certificate_arn, + var.saml_idp_metadata_url_arn, ]) } ] diff --git a/deployments/aws/ecs/secrets.tf b/deployments/aws/ecs/secrets.tf index e8d490f76..861a94ce3 100644 --- a/deployments/aws/ecs/secrets.tf +++ b/deployments/aws/ecs/secrets.tf @@ -118,7 +118,7 @@ data "aws_secretsmanager_secret_version" "temporal_db_password" { } locals { - base_secrets = [ + tracecat_base_secrets = [ { name = "TRACECAT__SERVICE_KEY" valueFrom = data.aws_secretsmanager_secret_version.tracecat_service_key.arn @@ -175,8 +175,8 @@ locals { } ] : [] - tracecat_secrets = concat( - local.base_secrets, + tracecat_api_secrets = concat( + local.tracecat_base_secrets, local.oauth_client_id_secret, local.oauth_client_secret_secret, local.saml_idp_entity_id_secret, From e22d3ba9abbd5173a058ddb1fc176c98c4f191c4 Mon Sep 17 00:00:00 2001 From: topher-lo <46541035+topher-lo@users.noreply.github.com> Date: Thu, 5 Dec 2024 21:01:33 -0800 Subject: [PATCH 08/13] ci(infra): Missing INTERNAL_REGISTRY_URL env var in docker compose --- .env.example | 2 +- docker-compose.dev.yml | 6 +++--- docker-compose.yml | 4 +++- 3 files changed, 7 insertions(+), 5 deletions(-) diff --git a/.env.example b/.env.example index 171c775bf..8f7a81f91 100644 --- a/.env.example +++ b/.env.example @@ -7,7 +7,7 @@ PUBLIC_APP_URL=http://localhost PUBLIC_API_URL=http://localhost/api SAML_SP_ACS_URL=${PUBLIC_API_URL}/auth/saml/acs INTERNAL_API_URL=http://api:8000 - +INTERNAL_REGISTRY_URL=http://registry:8000 # -- Caddy env vars --- BASE_DOMAIN=:80 # Note: replace with your server's IP address diff --git a/docker-compose.dev.yml b/docker-compose.dev.yml index cffea2e9b..94938730d 100644 --- a/docker-compose.dev.yml +++ b/docker-compose.dev.yml @@ -71,16 +71,16 @@ services: restart: unless-stopped environment: LOG_LEVEL: ${LOG_LEVEL} - TRACECAT__API_URL: ${TRACECAT__API_URL} TRACECAT__API_ROOT_PATH: ${TRACECAT__API_ROOT_PATH} + TRACECAT__API_URL: ${TRACECAT__API_URL} TRACECAT__APP_ENV: ${TRACECAT__APP_ENV} TRACECAT__DB_ENCRYPTION_KEY: ${TRACECAT__DB_ENCRYPTION_KEY} # Sensitive TRACECAT__DB_SSLMODE: ${TRACECAT__DB_SSLMODE} TRACECAT__DB_URI: ${TRACECAT__DB_URI} # Sensitive + TRACECAT__PUBLIC_RUNNER_URL: ${TRACECAT__PUBLIC_RUNNER_URL} + TRACECAT__REGISTRY_URL: ${INTERNAL_REGISTRY_URL} TRACECAT__SERVICE_KEY: ${TRACECAT__SERVICE_KEY} # Sensitive TRACECAT__SIGNING_SECRET: ${TRACECAT__SIGNING_SECRET} # Sensitive - TRACECAT__REGISTRY_URL: ${INTERNAL_REGISTRY_URL} - TRACECAT__PUBLIC_RUNNER_URL: ${TRACECAT__PUBLIC_RUNNER_URL} # Temporal TEMPORAL__CLUSTER_URL: ${TEMPORAL__CLUSTER_URL} TEMPORAL__CLUSTER_QUEUE: ${TEMPORAL__CLUSTER_QUEUE} diff --git a/docker-compose.yml b/docker-compose.yml index 95814a9fe..edb55d4f1 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -38,6 +38,7 @@ services: TRACECAT__AUTH_TYPES: ${TRACECAT__AUTH_TYPES} TRACECAT__AUTH_ALLOWED_DOMAINS: ${TRACECAT__AUTH_ALLOWED_DOMAINS} TRACECAT__AUTH_MIN_PASSWORD_LENGTH: ${TRACECAT__AUTH_MIN_PASSWORD_LENGTH} + TRACECAT__REGISTRY_URL: ${INTERNAL_REGISTRY_URL} OAUTH_CLIENT_ID: ${OAUTH_CLIENT_ID} OAUTH_CLIENT_SECRET: ${OAUTH_CLIENT_SECRET} USER_AUTH_SECRET: ${USER_AUTH_SECRET} @@ -67,13 +68,14 @@ services: - temporal environment: LOG_LEVEL: ${LOG_LEVEL} - TRACECAT__API_URL: ${TRACECAT__API_URL} TRACECAT__API_ROOT_PATH: ${TRACECAT__API_ROOT_PATH} + TRACECAT__API_URL: ${TRACECAT__API_URL} TRACECAT__APP_ENV: ${TRACECAT__APP_ENV} TRACECAT__DB_ENCRYPTION_KEY: ${TRACECAT__DB_ENCRYPTION_KEY} # Sensitive TRACECAT__DB_SSLMODE: ${TRACECAT__DB_SSLMODE} TRACECAT__DB_URI: ${TRACECAT__DB_URI} # Sensitive TRACECAT__PUBLIC_RUNNER_URL: ${TRACECAT__PUBLIC_RUNNER_URL} + TRACECAT__REGISTRY_URL: ${INTERNAL_REGISTRY_URL} TRACECAT__SERVICE_KEY: ${TRACECAT__SERVICE_KEY} # Sensitive TRACECAT__SIGNING_SECRET: ${TRACECAT__SIGNING_SECRET} # Sensitive # Temporal From 72ea5e4c057523335782bf5ee285ab5dcfde9055 Mon Sep 17 00:00:00 2001 From: topher-lo <46541035+topher-lo@users.noreply.github.com> Date: Thu, 5 Dec 2024 21:04:31 -0800 Subject: [PATCH 09/13] ci(infra): Add registry url to terraform fargate --- deployments/aws/ecs/ecs-registry.tf | 8 ++++---- deployments/aws/ecs/locals.tf | 25 ++++++++++++++----------- 2 files changed, 18 insertions(+), 15 deletions(-) diff --git a/deployments/aws/ecs/ecs-registry.tf b/deployments/aws/ecs/ecs-registry.tf index 978dd02bb..37a8a5497 100644 --- a/deployments/aws/ecs/ecs-registry.tf +++ b/deployments/aws/ecs/ecs-registry.tf @@ -48,8 +48,8 @@ resource "aws_ecs_task_definition" "registry_task_definition" { ]) depends_on = [ - aws_ecs_service.temporal_service, - aws_ecs_task_definition.temporal_task_definition, + aws_ecs_service.tracecat_api, + aws_ecs_service.tracecat_worker ] } @@ -95,7 +95,7 @@ resource "aws_ecs_service" "tracecat_registry" { } depends_on = [ - aws_ecs_service.temporal_service, - aws_ecs_task_definition.temporal_task_definition, + aws_ecs_service.tracecat_api, + aws_ecs_service.tracecat_worker ] } diff --git a/deployments/aws/ecs/locals.tf b/deployments/aws/ecs/locals.tf index c52ad10c1..d8285bed0 100644 --- a/deployments/aws/ecs/locals.tf +++ b/deployments/aws/ecs/locals.tf @@ -10,7 +10,8 @@ locals { public_app_url = "https://${var.domain_name}" public_api_url = "https://${var.domain_name}/api" saml_acs_url = "https://${var.domain_name}/api/auth/saml/acs" - internal_api_url = "http://api-service:8000" # Service connect DNS name + internal_api_url = "http://api-service:8000" # Service connect DNS name + internal_registry_url = "http://registry-service:8002" # Service connect DNS name temporal_cluster_url = "temporal-service:7233" temporal_cluster_queue = "tracecat-task-queue" allow_origins = "${var.domain_name},http://ui-service:3000" # Allow api service and public app to access the API @@ -27,22 +28,23 @@ locals { api_env = [ for k, v in merge({ LOG_LEVEL = var.log_level - TRACECAT__API_URL = local.internal_api_url + RUN_MIGRATIONS = "true" + SAML_SP_ACS_URL = local.saml_acs_url + TEMPORAL__CLIENT_RPC_TIMEOUT = var.temporal_client_rpc_timeout + TEMPORAL__CLUSTER_QUEUE = local.temporal_cluster_queue + TEMPORAL__CLUSTER_URL = local.temporal_cluster_url + TRACECAT__ALLOW_ORIGINS = local.allow_origins TRACECAT__API_ROOT_PATH = "/api" + TRACECAT__API_URL = local.internal_api_url TRACECAT__APP_ENV = var.tracecat_app_env + TRACECAT__AUTH_ALLOWED_DOMAINS = var.auth_allowed_domains + TRACECAT__AUTH_TYPES = var.auth_types TRACECAT__DB_ENDPOINT = local.core_db_hostname - TRACECAT__PUBLIC_RUNNER_URL = local.public_api_url TRACECAT__PUBLIC_APP_URL = local.public_app_url - TRACECAT__ALLOW_ORIGINS = local.allow_origins - TRACECAT__AUTH_TYPES = var.auth_types - TRACECAT__AUTH_ALLOWED_DOMAINS = var.auth_allowed_domains - TEMPORAL__CLUSTER_URL = local.temporal_cluster_url - TEMPORAL__CLUSTER_QUEUE = local.temporal_cluster_queue - TEMPORAL__CLIENT_RPC_TIMEOUT = var.temporal_client_rpc_timeout - SAML_SP_ACS_URL = local.saml_acs_url + TRACECAT__PUBLIC_RUNNER_URL = local.public_api_url TRACECAT__REMOTE_REPOSITORY_PACKAGE_NAME = var.remote_repository_package_name TRACECAT__REMOTE_REPOSITORY_URL = var.remote_repository_url - RUN_MIGRATIONS = "true" + TRACECAT__REGISTRY_URL = local.internal_registry_url }, local.tracecat_db_configs) : { name = k, value = tostring(v) } ] @@ -58,6 +60,7 @@ locals { TEMPORAL__CLUSTER_URL = local.temporal_cluster_url TEMPORAL__CLUSTER_QUEUE = local.temporal_cluster_queue TEMPORAL__CLIENT_RPC_TIMEOUT = var.temporal_client_rpc_timeout + TRACECAT__REGISTRY_URL = local.internal_registry_url }, local.tracecat_db_configs) : { name = k, value = tostring(v) } ] From 9cbf3d5acc562d05c82b5814dee23072e4bebedd Mon Sep 17 00:00:00 2001 From: topher-lo <46541035+topher-lo@users.noreply.github.com> Date: Thu, 5 Dec 2024 21:21:15 -0800 Subject: [PATCH 10/13] ci(infra): API depends on registry --- deployments/aws/ecs/ecs-api.tf | 5 +---- deployments/aws/ecs/ecs-registry.tf | 10 ---------- deployments/aws/ecs/ecs-worker.tf | 8 +------- 3 files changed, 2 insertions(+), 21 deletions(-) diff --git a/deployments/aws/ecs/ecs-api.tf b/deployments/aws/ecs/ecs-api.tf index 71c3cd8b4..9e471ee0d 100644 --- a/deployments/aws/ecs/ecs-api.tf +++ b/deployments/aws/ecs/ecs-api.tf @@ -36,10 +36,6 @@ resource "aws_ecs_task_definition" "api_task_definition" { } } ]) - - depends_on = [ - aws_ecs_service.temporal_service - ] } resource "aws_ecs_service" "tracecat_api" { @@ -85,6 +81,7 @@ resource "aws_ecs_service" "tracecat_api" { depends_on = [ aws_ecs_service.temporal_service, + aws_ecs_service.tracecat_registry ] } diff --git a/deployments/aws/ecs/ecs-registry.tf b/deployments/aws/ecs/ecs-registry.tf index 37a8a5497..098070233 100644 --- a/deployments/aws/ecs/ecs-registry.tf +++ b/deployments/aws/ecs/ecs-registry.tf @@ -46,11 +46,6 @@ resource "aws_ecs_task_definition" "registry_task_definition" { } } ]) - - depends_on = [ - aws_ecs_service.tracecat_api, - aws_ecs_service.tracecat_worker - ] } resource "aws_ecs_service" "tracecat_registry" { @@ -93,9 +88,4 @@ resource "aws_ecs_service" "tracecat_registry" { } } } - - depends_on = [ - aws_ecs_service.tracecat_api, - aws_ecs_service.tracecat_worker - ] } diff --git a/deployments/aws/ecs/ecs-worker.tf b/deployments/aws/ecs/ecs-worker.tf index 202034af0..2e52469c9 100644 --- a/deployments/aws/ecs/ecs-worker.tf +++ b/deployments/aws/ecs/ecs-worker.tf @@ -37,11 +37,6 @@ resource "aws_ecs_task_definition" "worker_task_definition" { } } ]) - - depends_on = [ - aws_ecs_service.temporal_service, - aws_ecs_task_definition.temporal_task_definition, - ] } resource "aws_ecs_service" "tracecat_worker" { @@ -86,7 +81,6 @@ resource "aws_ecs_service" "tracecat_worker" { } depends_on = [ - aws_ecs_service.temporal_service, - aws_ecs_service.tracecat_api, + aws_ecs_service.temporal_service ] } From 021be4d46083333ae25c7251acd1232598e18889 Mon Sep 17 00:00:00 2001 From: Daryl Lim <5508348+daryllimyt@users.noreply.github.com> Date: Sat, 7 Dec 2024 08:20:50 +0000 Subject: [PATCH 11/13] refactor(engine): Move registry management endpoints into api service + restructure registry as executor service (#590) --- .env.example | 9 +- .github/workflows/test-python.yml | 7 +- Caddyfile | 4 +- deployments/aws/ecs/ecs-api.tf | 3 +- deployments/aws/ecs/ecs-caddy.tf | 3 +- .../ecs/{ecs-registry.tf => ecs-executor.tf} | 36 +- deployments/aws/ecs/ecs-ui.tf | 1 + deployments/aws/ecs/ecs-worker.tf | 3 +- deployments/aws/ecs/iam.tf | 6 +- deployments/aws/ecs/locals.tf | 12 +- deployments/aws/ecs/security_groups.tf | 19 +- deployments/aws/ecs/variables.tf | 4 +- deployments/aws/main.tf | 4 +- deployments/aws/variables.tf | 4 +- docker-compose.dev.yml | 18 +- docker-compose.yml | 12 +- frontend/src/client/schemas.gen.ts | 154 +------- frontend/src/client/services.gen.ts | 276 ++++++--------- frontend/src/client/types.gen.ts | 331 ++++++------------ .../components/executions/event-details.tsx | 14 +- tests/conftest.py | 7 +- tests/unit/test_workflows.py | 2 +- tracecat/api/app.py | 6 + tracecat/api/{registry.py => executor.py} | 31 +- tracecat/config.py | 15 +- tracecat/db/schemas.py | 2 +- tracecat/identifiers/__init__.py | 2 +- tracecat/registry/actions/router.py | 89 ----- tracecat/registry/actions/service.py | 8 + tracecat/registry/client.py | 80 ++++- tracecat/registry/constants.py | 4 +- tracecat/registry/executor.py | 123 ++++++- tracecat/workflow/executions/service.py | 10 + 33 files changed, 565 insertions(+), 734 deletions(-) rename deployments/aws/ecs/{ecs-registry.tf => ecs-executor.tf} (68%) rename tracecat/api/{registry.py => executor.py} (64%) diff --git a/.env.example b/.env.example index 8f7a81f91..84291d209 100644 --- a/.env.example +++ b/.env.example @@ -7,7 +7,7 @@ PUBLIC_APP_URL=http://localhost PUBLIC_API_URL=http://localhost/api SAML_SP_ACS_URL=${PUBLIC_API_URL}/auth/saml/acs INTERNAL_API_URL=http://api:8000 -INTERNAL_REGISTRY_URL=http://registry:8000 +INTERNAL_EXECUTOR_URL=http://executor:8000 # -- Caddy env vars --- BASE_DOMAIN=:80 # Note: replace with your server's IP address @@ -29,18 +29,17 @@ TRACECAT__SIGNING_SECRET=your-tracecat-signing-secret TRACECAT__API_URL=${INTERNAL_API_URL} # Root path to deal with extra path prefix behind the reverse proxy TRACECAT__API_ROOT_PATH=/api -# Public Runner URL +# This the public URL for the frontend +TRACECAT__PUBLIC_APP_URL=${PUBLIC_APP_URL} # This is the public URL for incoming webhooks # If you wish to expose your webhooks to the internet, you can use a tunneling service like ngrok. # If using ngrok, run `ngrok http --domain=INSERT_STATIC_NGROK_DOMAIN_HERE 8001` # to start ngrok and update this with the forwarding URL -TRACECAT__PUBLIC_RUNNER_URL=${PUBLIC_API_URL} +TRACECAT__PUBLIC_API_URL=${PUBLIC_API_URL} # CORS (comman separated string of allowed origins) TRACECAT__ALLOW_ORIGINS=http://localhost:3000,${PUBLIC_APP_URL} # Postgres SSL model TRACECAT__DB_SSLMODE=disable -TRACECAT__PUBLIC_APP_URL=${PUBLIC_APP_URL} -TRACECAT__PUBLIC_API_URL=${PUBLIC_API_URL} # --- Postgres --- TRACECAT__POSTGRES_USER=postgres diff --git a/.github/workflows/test-python.yml b/.github/workflows/test-python.yml index 3b29154f5..0bacf0e1e 100644 --- a/.github/workflows/test-python.yml +++ b/.github/workflows/test-python.yml @@ -10,7 +10,7 @@ on: - pyproject.toml - .github/workflows/test-python.yml pull_request: - branches: ["main"] + branches: ["main", "staging"] paths: - tracecat/** - registry/** @@ -21,8 +21,7 @@ on: inputs: git-ref: description: "Git Ref (Optional)" - required: false - default: "main" + required: true permissions: contents: read @@ -126,7 +125,7 @@ jobs: - name: Start Docker services env: TRACECAT__UNSAFE_DISABLE_SM_MASKING: "true" - run: docker compose -f docker-compose.dev.yml up --build --no-deps -d api worker registry postgres_db caddy + run: docker compose -f docker-compose.dev.yml up --build --no-deps -d api worker executor postgres_db caddy - name: Install dependencies run: | diff --git a/Caddyfile b/Caddyfile index 16fe5f3bb..bf9fb62c3 100644 --- a/Caddyfile +++ b/Caddyfile @@ -1,7 +1,7 @@ {$BASE_DOMAIN} { bind {$ADDRESS} # Binds to all available network interfaces if not specified - handle_path /api/registry* { - reverse_proxy http://registry:8000 + handle_path /api/executor* { + reverse_proxy http://executor:8000 } handle_path /api* { reverse_proxy http://api:8000 diff --git a/deployments/aws/ecs/ecs-api.tf b/deployments/aws/ecs/ecs-api.tf index 9e471ee0d..31f5521b6 100644 --- a/deployments/aws/ecs/ecs-api.tf +++ b/deployments/aws/ecs/ecs-api.tf @@ -50,6 +50,7 @@ resource "aws_ecs_service" "tracecat_api" { subnets = var.private_subnet_ids security_groups = [ aws_security_group.core.id, + aws_security_group.caddy.id, aws_security_group.core_db.id, ] } @@ -81,7 +82,7 @@ resource "aws_ecs_service" "tracecat_api" { depends_on = [ aws_ecs_service.temporal_service, - aws_ecs_service.tracecat_registry + aws_ecs_service.tracecat_executor ] } diff --git a/deployments/aws/ecs/ecs-caddy.tf b/deployments/aws/ecs/ecs-caddy.tf index 84b9a2b02..ee7ca11c4 100644 --- a/deployments/aws/ecs/ecs-caddy.tf +++ b/deployments/aws/ecs/ecs-caddy.tf @@ -62,8 +62,7 @@ resource "aws_ecs_service" "tracecat_caddy" { network_configuration { subnets = var.private_subnet_ids security_groups = [ - aws_security_group.caddy.id, - aws_security_group.core.id + aws_security_group.caddy.id ] } diff --git a/deployments/aws/ecs/ecs-registry.tf b/deployments/aws/ecs/ecs-executor.tf similarity index 68% rename from deployments/aws/ecs/ecs-registry.tf rename to deployments/aws/ecs/ecs-executor.tf index 098070233..8677a98d7 100644 --- a/deployments/aws/ecs/ecs-registry.tf +++ b/deployments/aws/ecs/ecs-executor.tf @@ -1,32 +1,32 @@ -# ECS Task Definition for Registry Service -resource "aws_ecs_task_definition" "registry_task_definition" { - family = "TracecatRegistryTaskDefinition" +# ECS Task Definition for Executor Service +resource "aws_ecs_task_definition" "executor_task_definition" { + family = "TracecatExecutorTaskDefinition" network_mode = "awsvpc" requires_compatibilities = ["FARGATE"] - cpu = var.registry_cpu - memory = var.registry_memory + cpu = var.executor_cpu + memory = var.executor_memory execution_role_arn = aws_iam_role.worker_execution.arn task_role_arn = aws_iam_role.api_worker_task.arn container_definitions = jsonencode([ { - name = "TracecatRegistryContainer" + name = "TracecatExecutorContainer" image = "${var.tracecat_image}:${local.tracecat_image_tag}" command = [ "python", "-m", "uvicorn", - "tracecat.api.registry:app", + "tracecat.api.executor:app", "--host", "0.0.0.0", "--port", - "8002" + "8000" ] portMappings = [ { containerPort = 8002 hostPort = 8002 - name = "registry" + name = "executor" appProtocol = "http" } ] @@ -35,10 +35,10 @@ resource "aws_ecs_task_definition" "registry_task_definition" { options = { awslogs-group = aws_cloudwatch_log_group.tracecat_log_group.name awslogs-region = var.aws_region - awslogs-stream-prefix = "registry" + awslogs-stream-prefix = "executor" } } - environment = local.registry_env + environment = local.executor_env secrets = local.tracecat_base_secrets dockerPullConfig = { maxAttempts = 3 @@ -48,10 +48,10 @@ resource "aws_ecs_task_definition" "registry_task_definition" { ]) } -resource "aws_ecs_service" "tracecat_registry" { - name = "tracecat-registry" +resource "aws_ecs_service" "tracecat_executor" { + name = "tracecat-executor" cluster = aws_ecs_cluster.tracecat_cluster.id - task_definition = aws_ecs_task_definition.registry_task_definition.arn + task_definition = aws_ecs_task_definition.executor_task_definition.arn launch_type = "FARGATE" desired_count = 1 force_new_deployment = var.force_new_deployment @@ -68,14 +68,14 @@ resource "aws_ecs_service" "tracecat_registry" { enabled = true namespace = local.local_dns_namespace service { - port_name = "registry" - discovery_name = "registry-service" + port_name = "executor" + discovery_name = "executor-service" timeout { per_request_timeout_seconds = 120 } client_alias { port = 8002 - dns_name = "registry-service" + dns_name = "executor-service" } } @@ -84,7 +84,7 @@ resource "aws_ecs_service" "tracecat_registry" { options = { awslogs-group = aws_cloudwatch_log_group.tracecat_log_group.name awslogs-region = var.aws_region - awslogs-stream-prefix = "service-connect-registry" + awslogs-stream-prefix = "service-connect-executor" } } } diff --git a/deployments/aws/ecs/ecs-ui.tf b/deployments/aws/ecs/ecs-ui.tf index 9af885072..c077b9f91 100644 --- a/deployments/aws/ecs/ecs-ui.tf +++ b/deployments/aws/ecs/ecs-ui.tf @@ -48,6 +48,7 @@ resource "aws_ecs_service" "tracecat_ui" { subnets = var.private_subnet_ids security_groups = [ aws_security_group.core.id, + aws_security_group.caddy.id ] } diff --git a/deployments/aws/ecs/ecs-worker.tf b/deployments/aws/ecs/ecs-worker.tf index 2e52469c9..d30807e1e 100644 --- a/deployments/aws/ecs/ecs-worker.tf +++ b/deployments/aws/ecs/ecs-worker.tf @@ -81,6 +81,7 @@ resource "aws_ecs_service" "tracecat_worker" { } depends_on = [ - aws_ecs_service.temporal_service + aws_ecs_service.temporal_service, + aws_ecs_service.tracecat_executor ] } diff --git a/deployments/aws/ecs/iam.tf b/deployments/aws/ecs/iam.tf index cf21cd95c..f0deae084 100644 --- a/deployments/aws/ecs/iam.tf +++ b/deployments/aws/ecs/iam.tf @@ -114,9 +114,9 @@ resource "aws_iam_role_policy_attachment" "worker_execution_secrets" { role = aws_iam_role.worker_execution.name } -# Registry execution role -resource "aws_iam_role" "registry_execution" { - name = "TracecatRegistryExecutionRole" +# Executor execution role +resource "aws_iam_role" "executor_execution" { + name = "TracecatExecutorExecutionRole" assume_role_policy = data.aws_iam_policy_document.assume_role.json } diff --git a/deployments/aws/ecs/locals.tf b/deployments/aws/ecs/locals.tf index d8285bed0..0669f94fc 100644 --- a/deployments/aws/ecs/locals.tf +++ b/deployments/aws/ecs/locals.tf @@ -11,7 +11,7 @@ locals { public_api_url = "https://${var.domain_name}/api" saml_acs_url = "https://${var.domain_name}/api/auth/saml/acs" internal_api_url = "http://api-service:8000" # Service connect DNS name - internal_registry_url = "http://registry-service:8002" # Service connect DNS name + internal_executor_url = "http://executor-service:8002" # Service connect DNS name temporal_cluster_url = "temporal-service:7233" temporal_cluster_queue = "tracecat-task-queue" allow_origins = "${var.domain_name},http://ui-service:3000" # Allow api service and public app to access the API @@ -41,10 +41,10 @@ locals { TRACECAT__AUTH_TYPES = var.auth_types TRACECAT__DB_ENDPOINT = local.core_db_hostname TRACECAT__PUBLIC_APP_URL = local.public_app_url - TRACECAT__PUBLIC_RUNNER_URL = local.public_api_url + TRACECAT__PUBLIC_API_URL = local.public_api_url TRACECAT__REMOTE_REPOSITORY_PACKAGE_NAME = var.remote_repository_package_name TRACECAT__REMOTE_REPOSITORY_URL = var.remote_repository_url - TRACECAT__REGISTRY_URL = local.internal_registry_url + TRACECAT__EXECUTOR_URL = local.internal_executor_url }, local.tracecat_db_configs) : { name = k, value = tostring(v) } ] @@ -56,16 +56,16 @@ locals { TRACECAT__API_ROOT_PATH = "/api" TRACECAT__APP_ENV = var.tracecat_app_env TRACECAT__DB_ENDPOINT = local.core_db_hostname - TRACECAT__PUBLIC_RUNNER_URL = local.public_api_url + TRACECAT__PUBLIC_API_URL = local.public_api_url TEMPORAL__CLUSTER_URL = local.temporal_cluster_url TEMPORAL__CLUSTER_QUEUE = local.temporal_cluster_queue TEMPORAL__CLIENT_RPC_TIMEOUT = var.temporal_client_rpc_timeout - TRACECAT__REGISTRY_URL = local.internal_registry_url + TRACECAT__EXECUTOR_URL = local.internal_executor_url }, local.tracecat_db_configs) : { name = k, value = tostring(v) } ] - registry_env = [ + executor_env = [ for k, v in merge({ LOG_LEVEL = var.log_level TRACECAT__APP_ENV = var.tracecat_app_env diff --git a/deployments/aws/ecs/security_groups.tf b/deployments/aws/ecs/security_groups.tf index 392578ab3..f8ef96e91 100644 --- a/deployments/aws/ecs/security_groups.tf +++ b/deployments/aws/ecs/security_groups.tf @@ -31,12 +31,29 @@ resource "aws_security_group" "caddy" { vpc_id = var.vpc_id ingress { + description = "Allow inbound access from ALB to port 80 (Caddy) only" protocol = "tcp" from_port = 80 to_port = 80 security_groups = [aws_security_group.alb.id] } + ingress { + description = "Allow Caddy to forward traffic to API service only" + protocol = "tcp" + from_port = 8000 + to_port = 8000 + self = true + } + + ingress { + description = "Allow Caddy to forward traffic to UI service only" + protocol = "tcp" + from_port = 3000 + to_port = 3000 + self = true + } + egress { protocol = "-1" from_port = 0 @@ -67,7 +84,7 @@ resource "aws_security_group" "core" { } ingress { - description = "Allow internal traffic to the Tracecat Registry service on port 8002" + description = "Allow internal traffic to the Tracecat Executor service on port 8000" from_port = 8002 to_port = 8002 protocol = "tcp" diff --git a/deployments/aws/ecs/variables.tf b/deployments/aws/ecs/variables.tf index 8527b4908..07dd05724 100644 --- a/deployments/aws/ecs/variables.tf +++ b/deployments/aws/ecs/variables.tf @@ -212,12 +212,12 @@ variable "worker_memory" { default = "512" } -variable "registry_cpu" { +variable "executor_cpu" { type = string default = "256" } -variable "registry_memory" { +variable "executor_memory" { type = string default = "512" } diff --git a/deployments/aws/main.tf b/deployments/aws/main.tf index bef4b235b..1f6543a14 100644 --- a/deployments/aws/main.tf +++ b/deployments/aws/main.tf @@ -78,8 +78,8 @@ module "ecs" { api_memory = var.api_memory worker_cpu = var.worker_cpu worker_memory = var.worker_memory - registry_cpu = var.registry_cpu - registry_memory = var.registry_memory + executor_cpu = var.executor_cpu + executor_memory = var.executor_memory ui_cpu = var.ui_cpu ui_memory = var.ui_memory temporal_cpu = var.temporal_cpu diff --git a/deployments/aws/variables.tf b/deployments/aws/variables.tf index b35036bbf..37875e6b7 100644 --- a/deployments/aws/variables.tf +++ b/deployments/aws/variables.tf @@ -145,12 +145,12 @@ variable "worker_memory" { default = "512" } -variable "registry_cpu" { +variable "executor_cpu" { type = string default = "256" } -variable "registry_memory" { +variable "executor_memory" { type = string default = "512" } diff --git a/docker-compose.dev.yml b/docker-compose.dev.yml index 94938730d..ce96f42be 100644 --- a/docker-compose.dev.yml +++ b/docker-compose.dev.yml @@ -28,7 +28,7 @@ services: TRACECAT__DB_ENCRYPTION_KEY: ${TRACECAT__DB_ENCRYPTION_KEY} # Sensitive TRACECAT__DB_SSLMODE: ${TRACECAT__DB_SSLMODE} TRACECAT__DB_URI: ${TRACECAT__DB_URI} # Sensitive - TRACECAT__PUBLIC_RUNNER_URL: ${TRACECAT__PUBLIC_RUNNER_URL} + TRACECAT__PUBLIC_API_URL: ${TRACECAT__PUBLIC_API_URL} TRACECAT__PUBLIC_APP_URL: ${TRACECAT__PUBLIC_APP_URL} TRACECAT__SERVICE_KEY: ${TRACECAT__SERVICE_KEY} # Sensitive TRACECAT__SIGNING_SECRET: ${TRACECAT__SIGNING_SECRET} # Sensitive @@ -36,7 +36,7 @@ services: TRACECAT__AUTH_TYPES: ${TRACECAT__AUTH_TYPES} TRACECAT__AUTH_ALLOWED_DOMAINS: ${TRACECAT__AUTH_ALLOWED_DOMAINS} TRACECAT__AUTH_MIN_PASSWORD_LENGTH: ${TRACECAT__AUTH_MIN_PASSWORD_LENGTH} - TRACECAT__REGISTRY_URL: ${INTERNAL_REGISTRY_URL} + TRACECAT__EXECUTOR_URL: ${INTERNAL_EXECUTOR_URL} OAUTH_CLIENT_ID: ${OAUTH_CLIENT_ID} OAUTH_CLIENT_SECRET: ${OAUTH_CLIENT_SECRET} USER_AUTH_SECRET: ${USER_AUTH_SECRET} @@ -63,6 +63,7 @@ services: - ./alembic:/app/alembic depends_on: - ollama + - executor worker: build: @@ -77,8 +78,8 @@ services: TRACECAT__DB_ENCRYPTION_KEY: ${TRACECAT__DB_ENCRYPTION_KEY} # Sensitive TRACECAT__DB_SSLMODE: ${TRACECAT__DB_SSLMODE} TRACECAT__DB_URI: ${TRACECAT__DB_URI} # Sensitive - TRACECAT__PUBLIC_RUNNER_URL: ${TRACECAT__PUBLIC_RUNNER_URL} - TRACECAT__REGISTRY_URL: ${INTERNAL_REGISTRY_URL} + TRACECAT__PUBLIC_API_URL: ${TRACECAT__PUBLIC_API_URL} + TRACECAT__EXECUTOR_URL: ${INTERNAL_EXECUTOR_URL} TRACECAT__SERVICE_KEY: ${TRACECAT__SERVICE_KEY} # Sensitive TRACECAT__SIGNING_SECRET: ${TRACECAT__SIGNING_SECRET} # Sensitive # Temporal @@ -87,9 +88,9 @@ services: volumes: - ./tracecat:/app/tracecat - ./registry:/app/registry - entrypoint: ["python", "tracecat/dsl/worker.py"] + command: ["python", "tracecat/dsl/worker.py"] - registry: + executor: build: context: . dockerfile: Dockerfile.dev @@ -112,13 +113,12 @@ services: OLLAMA__API_URL: ${OLLAMA__API_URL} volumes: - ./tracecat:/app/tracecat - - ./registry:/app/registry - entrypoint: + command: [ "python", "-m", "uvicorn", - "tracecat.api.registry:app", + "tracecat.api.executor:app", "--host", "0.0.0.0", "--port", diff --git a/docker-compose.yml b/docker-compose.yml index edb55d4f1..58c6891a6 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -30,7 +30,7 @@ services: TRACECAT__DB_ENCRYPTION_KEY: ${TRACECAT__DB_ENCRYPTION_KEY} # Sensitive TRACECAT__DB_SSLMODE: ${TRACECAT__DB_SSLMODE} TRACECAT__DB_URI: ${TRACECAT__DB_URI} # Sensitive - TRACECAT__PUBLIC_RUNNER_URL: ${TRACECAT__PUBLIC_RUNNER_URL} + TRACECAT__PUBLIC_API_URL: ${TRACECAT__PUBLIC_API_URL} TRACECAT__PUBLIC_APP_URL: ${TRACECAT__PUBLIC_APP_URL} TRACECAT__SERVICE_KEY: ${TRACECAT__SERVICE_KEY} # Sensitive TRACECAT__SIGNING_SECRET: ${TRACECAT__SIGNING_SECRET} # Sensitive @@ -38,7 +38,7 @@ services: TRACECAT__AUTH_TYPES: ${TRACECAT__AUTH_TYPES} TRACECAT__AUTH_ALLOWED_DOMAINS: ${TRACECAT__AUTH_ALLOWED_DOMAINS} TRACECAT__AUTH_MIN_PASSWORD_LENGTH: ${TRACECAT__AUTH_MIN_PASSWORD_LENGTH} - TRACECAT__REGISTRY_URL: ${INTERNAL_REGISTRY_URL} + TRACECAT__EXECUTOR_URL: ${INTERNAL_EXECUTOR_URL} OAUTH_CLIENT_ID: ${OAUTH_CLIENT_ID} OAUTH_CLIENT_SECRET: ${OAUTH_CLIENT_SECRET} USER_AUTH_SECRET: ${USER_AUTH_SECRET} @@ -74,8 +74,8 @@ services: TRACECAT__DB_ENCRYPTION_KEY: ${TRACECAT__DB_ENCRYPTION_KEY} # Sensitive TRACECAT__DB_SSLMODE: ${TRACECAT__DB_SSLMODE} TRACECAT__DB_URI: ${TRACECAT__DB_URI} # Sensitive - TRACECAT__PUBLIC_RUNNER_URL: ${TRACECAT__PUBLIC_RUNNER_URL} - TRACECAT__REGISTRY_URL: ${INTERNAL_REGISTRY_URL} + TRACECAT__PUBLIC_API_URL: ${TRACECAT__PUBLIC_API_URL} + TRACECAT__EXECUTOR_URL: ${INTERNAL_EXECUTOR_URL} TRACECAT__SERVICE_KEY: ${TRACECAT__SERVICE_KEY} # Sensitive TRACECAT__SIGNING_SECRET: ${TRACECAT__SIGNING_SECRET} # Sensitive # Temporal @@ -83,7 +83,7 @@ services: TEMPORAL__CLUSTER_QUEUE: ${TEMPORAL__CLUSTER_QUEUE} command: ["python", "tracecat/dsl/worker.py"] - registry: + executor: image: ghcr.io/tracecathq/tracecat:${TRACECAT__IMAGE_TAG:-0.16.0} restart: unless-stopped networks: @@ -108,7 +108,7 @@ services: "python", "-m", "uvicorn", - "tracecat.api.registry:app", + "tracecat.api.executor:app", "--host", "0.0.0.0", "--port", diff --git a/frontend/src/client/schemas.gen.ts b/frontend/src/client/schemas.gen.ts index 79f56de94..820a48ae8 100644 --- a/frontend/src/client/schemas.gen.ts +++ b/frontend/src/client/schemas.gen.ts @@ -173,110 +173,7 @@ export const $ActionRetryPolicy = { title: 'ActionRetryPolicy' } as const; -export const $ActionStatement_Input = { - properties: { - id: { - anyOf: [ - { - type: 'string' - }, - { - type: 'null' - } - ], - title: 'Id', - description: 'The action ID. If this is populated means there is a corresponding actionin the database `Action` table.' - }, - ref: { - type: 'string', - pattern: '^[a-z0-9_]+$', - title: 'Ref', - description: 'Unique reference for the task' - }, - description: { - type: 'string', - title: 'Description', - default: '' - }, - action: { - type: 'string', - pattern: '^[a-z0-9_.]+$', - title: 'Action', - description: 'Action type. Equivalent to the UDF key.' - }, - args: { - type: 'object', - title: 'Args', - description: 'Arguments for the action' - }, - depends_on: { - items: { - type: 'string' - }, - type: 'array', - title: 'Depends On', - description: 'Task dependencies' - }, - run_if: { - anyOf: [ - { - type: 'string' - }, - { - type: 'null' - } - ], - title: 'Run If', - description: 'Condition to run the task' - }, - for_each: { - anyOf: [ - { - type: 'string' - }, - { - items: { - type: 'string' - }, - type: 'array' - }, - { - type: 'null' - } - ], - title: 'For Each', - description: 'Iterate over a list of items and run the task for each item.' - }, - retry_policy: { - allOf: [ - { - '$ref': '#/components/schemas/ActionRetryPolicy' - } - ], - description: 'Retry policy for the action.' - }, - start_delay: { - type: 'number', - title: 'Start Delay', - description: 'Delay before starting the action in seconds.', - default: 0 - }, - join_strategy: { - allOf: [ - { - '$ref': '#/components/schemas/JoinStrategy' - } - ], - description: 'The strategy to use when joining on this task. By default, all branches must complete successfully before the join task can complete.', - default: 'all' - } - }, - type: 'object', - required: ['ref', 'action'], - title: 'ActionStatement' -} as const; - -export const $ActionStatement_Output = { +export const $ActionStatement = { properties: { ref: { type: 'string', @@ -758,10 +655,15 @@ export const $DSLContext = { }, ENV: { '$ref': '#/components/schemas/DSLEnvironment' + }, + SECRETS: { + type: 'object', + title: 'Secrets' } }, type: 'object', - title: 'DSLContext' + title: 'DSLContext', + description: 'DSL Context. Contains all the context needed to execute a DSL workflow.' } as const; export const $DSLEntrypoint = { @@ -837,7 +739,7 @@ export const $DSLInput = { }, actions: { items: { - '$ref': '#/components/schemas/ActionStatement-Output' + '$ref': '#/components/schemas/ActionStatement' }, type: 'array', title: 'Actions' @@ -1118,7 +1020,7 @@ export const $EventGroup = { action_input: { anyOf: [ { - '$ref': '#/components/schemas/RunActionInput-Output' + '$ref': '#/components/schemas/RunActionInput' }, { '$ref': '#/components/schemas/DSLRunArgs' @@ -1333,7 +1235,7 @@ export const $GetWorkflowDefinitionActivityInputs = { task: { anyOf: [ { - '$ref': '#/components/schemas/ActionStatement-Output' + '$ref': '#/components/schemas/ActionStatement' }, { type: 'null' @@ -1868,18 +1770,6 @@ export const $RegistryActionUpdate = { description: 'API update model for a registered action.' } as const; -export const $RegistryActionValidate = { - properties: { - args: { - type: 'object', - title: 'Args' - } - }, - type: 'object', - required: ['args'], - title: 'RegistryActionValidate' -} as const; - export const $RegistryActionValidateResponse = { properties: { ok: { @@ -2086,7 +1976,7 @@ export const $Role = { }, service_id: { type: 'string', - enum: ['tracecat-runner', 'tracecat-api', 'tracecat-cli', 'tracecat-schedule-runner', 'tracecat-service'], + enum: ['tracecat-runner', 'tracecat-api', 'tracecat-cli', 'tracecat-schedule-runner', 'tracecat-service', 'tracecat-executor'], title: 'Service Id' } }, @@ -2119,28 +2009,10 @@ Service roles - A service's \`user_id\` is the user it's acting on behalf of. This can be None for internal services.` } as const; -export const $RunActionInput_Input = { - properties: { - task: { - '$ref': '#/components/schemas/ActionStatement-Input' - }, - exec_context: { - '$ref': '#/components/schemas/DSLContext' - }, - run_context: { - '$ref': '#/components/schemas/RunContext' - } - }, - type: 'object', - required: ['task', 'exec_context', 'run_context'], - title: 'RunActionInput', - description: 'This object contains all the information needed to execute an action.' -} as const; - -export const $RunActionInput_Output = { +export const $RunActionInput = { properties: { task: { - '$ref': '#/components/schemas/ActionStatement-Output' + '$ref': '#/components/schemas/ActionStatement' }, exec_context: { '$ref': '#/components/schemas/DSLContext' diff --git a/frontend/src/client/services.gen.ts b/frontend/src/client/services.gen.ts index e707bcc47..842c398a6 100644 --- a/frontend/src/client/services.gen.ts +++ b/frontend/src/client/services.gen.ts @@ -3,7 +3,7 @@ import type { CancelablePromise } from './core/CancelablePromise'; import { OpenAPI } from './core/OpenAPI'; import { request as __request } from './core/request'; -import type { PublicIncomingWebhookData, PublicIncomingWebhookResponse, PublicIncomingWebhookWaitData, PublicIncomingWebhookWaitResponse, WorkspacesListWorkspacesResponse, WorkspacesCreateWorkspaceData, WorkspacesCreateWorkspaceResponse, WorkspacesSearchWorkspacesData, WorkspacesSearchWorkspacesResponse, WorkspacesGetWorkspaceData, WorkspacesGetWorkspaceResponse, WorkspacesUpdateWorkspaceData, WorkspacesUpdateWorkspaceResponse, WorkspacesDeleteWorkspaceData, WorkspacesDeleteWorkspaceResponse, WorkspacesListWorkspaceMembershipsData, WorkspacesListWorkspaceMembershipsResponse, WorkspacesCreateWorkspaceMembershipData, WorkspacesCreateWorkspaceMembershipResponse, WorkspacesGetWorkspaceMembershipData, WorkspacesGetWorkspaceMembershipResponse, WorkspacesDeleteWorkspaceMembershipData, WorkspacesDeleteWorkspaceMembershipResponse, WorkflowsListWorkflowsData, WorkflowsListWorkflowsResponse, WorkflowsCreateWorkflowData, WorkflowsCreateWorkflowResponse, WorkflowsGetWorkflowData, WorkflowsGetWorkflowResponse, WorkflowsUpdateWorkflowData, WorkflowsUpdateWorkflowResponse, WorkflowsDeleteWorkflowData, WorkflowsDeleteWorkflowResponse, WorkflowsCommitWorkflowData, WorkflowsCommitWorkflowResponse, WorkflowsExportWorkflowData, WorkflowsExportWorkflowResponse, WorkflowsGetWorkflowDefinitionData, WorkflowsGetWorkflowDefinitionResponse, WorkflowsCreateWorkflowDefinitionData, WorkflowsCreateWorkflowDefinitionResponse, TriggersCreateWebhookData, TriggersCreateWebhookResponse, TriggersGetWebhookData, TriggersGetWebhookResponse, TriggersUpdateWebhookData, TriggersUpdateWebhookResponse, WorkflowExecutionsListWorkflowExecutionsData, WorkflowExecutionsListWorkflowExecutionsResponse, WorkflowExecutionsCreateWorkflowExecutionData, WorkflowExecutionsCreateWorkflowExecutionResponse, WorkflowExecutionsGetWorkflowExecutionData, WorkflowExecutionsGetWorkflowExecutionResponse, WorkflowExecutionsListWorkflowExecutionEventHistoryData, WorkflowExecutionsListWorkflowExecutionEventHistoryResponse, WorkflowExecutionsCancelWorkflowExecutionData, WorkflowExecutionsCancelWorkflowExecutionResponse, WorkflowExecutionsTerminateWorkflowExecutionData, WorkflowExecutionsTerminateWorkflowExecutionResponse, ActionsListActionsData, ActionsListActionsResponse, ActionsCreateActionData, ActionsCreateActionResponse, ActionsGetActionData, ActionsGetActionResponse, ActionsUpdateActionData, ActionsUpdateActionResponse, ActionsDeleteActionData, ActionsDeleteActionResponse, SecretsSearchSecretsData, SecretsSearchSecretsResponse, SecretsListSecretsData, SecretsListSecretsResponse, SecretsCreateSecretData, SecretsCreateSecretResponse, SecretsGetSecretByNameData, SecretsGetSecretByNameResponse, SecretsUpdateSecretByIdData, SecretsUpdateSecretByIdResponse, SecretsDeleteSecretByIdData, SecretsDeleteSecretByIdResponse, SchedulesListSchedulesData, SchedulesListSchedulesResponse, SchedulesCreateScheduleData, SchedulesCreateScheduleResponse, SchedulesGetScheduleData, SchedulesGetScheduleResponse, SchedulesUpdateScheduleData, SchedulesUpdateScheduleResponse, SchedulesDeleteScheduleData, SchedulesDeleteScheduleResponse, SchedulesSearchSchedulesData, SchedulesSearchSchedulesResponse, UsersSearchUserData, UsersSearchUserResponse, RegistryRepositoriesSyncRegistryRepositoriesData, RegistryRepositoriesSyncRegistryRepositoriesResponse, RegistryRepositoriesListRegistryRepositoriesResponse, RegistryRepositoriesCreateRegistryRepositoryData, RegistryRepositoriesCreateRegistryRepositoryResponse, RegistryRepositoriesGetRegistryRepositoryData, RegistryRepositoriesGetRegistryRepositoryResponse, RegistryRepositoriesUpdateRegistryRepositoryData, RegistryRepositoriesUpdateRegistryRepositoryResponse, RegistryRepositoriesDeleteRegistryRepositoryData, RegistryRepositoriesDeleteRegistryRepositoryResponse, RegistryActionsListRegistryActionsResponse, RegistryActionsCreateRegistryActionData, RegistryActionsCreateRegistryActionResponse, RegistryActionsGetRegistryActionData, RegistryActionsGetRegistryActionResponse, RegistryActionsUpdateRegistryActionData, RegistryActionsUpdateRegistryActionResponse, RegistryActionsDeleteRegistryActionData, RegistryActionsDeleteRegistryActionResponse, RegistryActionsRunRegistryActionData, RegistryActionsRunRegistryActionResponse, RegistryActionsValidateRegistryActionData, RegistryActionsValidateRegistryActionResponse, OrganizationListOrgMembersResponse, OrganizationDeleteOrgMemberData, OrganizationDeleteOrgMemberResponse, OrganizationUpdateOrgMemberData, OrganizationUpdateOrgMemberResponse, OrganizationListSessionsResponse, OrganizationDeleteSessionData, OrganizationDeleteSessionResponse, EditorListFunctionsData, EditorListFunctionsResponse, EditorListActionsData, EditorListActionsResponse, UsersUsersCurrentUserResponse, UsersUsersPatchCurrentUserData, UsersUsersPatchCurrentUserResponse, UsersUsersUserData, UsersUsersUserResponse, UsersUsersPatchUserData, UsersUsersPatchUserResponse, UsersUsersDeleteUserData, UsersUsersDeleteUserResponse, AuthAuthDatabaseLoginData, AuthAuthDatabaseLoginResponse, AuthAuthDatabaseLogoutResponse, AuthRegisterRegisterData, AuthRegisterRegisterResponse, AuthResetForgotPasswordData, AuthResetForgotPasswordResponse, AuthResetResetPasswordData, AuthResetResetPasswordResponse, AuthVerifyRequestTokenData, AuthVerifyRequestTokenResponse, AuthVerifyVerifyData, AuthVerifyVerifyResponse, AuthOauthGoogleDatabaseAuthorizeData, AuthOauthGoogleDatabaseAuthorizeResponse, AuthOauthGoogleDatabaseCallbackData, AuthOauthGoogleDatabaseCallbackResponse, AuthSamlDatabaseLoginResponse, AuthSsoAcsData, AuthSsoAcsResponse, PublicCheckHealthResponse } from './types.gen'; +import type { PublicIncomingWebhookData, PublicIncomingWebhookResponse, PublicIncomingWebhookWaitData, PublicIncomingWebhookWaitResponse, WorkspacesListWorkspacesResponse, WorkspacesCreateWorkspaceData, WorkspacesCreateWorkspaceResponse, WorkspacesSearchWorkspacesData, WorkspacesSearchWorkspacesResponse, WorkspacesGetWorkspaceData, WorkspacesGetWorkspaceResponse, WorkspacesUpdateWorkspaceData, WorkspacesUpdateWorkspaceResponse, WorkspacesDeleteWorkspaceData, WorkspacesDeleteWorkspaceResponse, WorkspacesListWorkspaceMembershipsData, WorkspacesListWorkspaceMembershipsResponse, WorkspacesCreateWorkspaceMembershipData, WorkspacesCreateWorkspaceMembershipResponse, WorkspacesGetWorkspaceMembershipData, WorkspacesGetWorkspaceMembershipResponse, WorkspacesDeleteWorkspaceMembershipData, WorkspacesDeleteWorkspaceMembershipResponse, WorkflowsListWorkflowsData, WorkflowsListWorkflowsResponse, WorkflowsCreateWorkflowData, WorkflowsCreateWorkflowResponse, WorkflowsGetWorkflowData, WorkflowsGetWorkflowResponse, WorkflowsUpdateWorkflowData, WorkflowsUpdateWorkflowResponse, WorkflowsDeleteWorkflowData, WorkflowsDeleteWorkflowResponse, WorkflowsCommitWorkflowData, WorkflowsCommitWorkflowResponse, WorkflowsExportWorkflowData, WorkflowsExportWorkflowResponse, WorkflowsGetWorkflowDefinitionData, WorkflowsGetWorkflowDefinitionResponse, WorkflowsCreateWorkflowDefinitionData, WorkflowsCreateWorkflowDefinitionResponse, TriggersCreateWebhookData, TriggersCreateWebhookResponse, TriggersGetWebhookData, TriggersGetWebhookResponse, TriggersUpdateWebhookData, TriggersUpdateWebhookResponse, WorkflowExecutionsListWorkflowExecutionsData, WorkflowExecutionsListWorkflowExecutionsResponse, WorkflowExecutionsCreateWorkflowExecutionData, WorkflowExecutionsCreateWorkflowExecutionResponse, WorkflowExecutionsGetWorkflowExecutionData, WorkflowExecutionsGetWorkflowExecutionResponse, WorkflowExecutionsListWorkflowExecutionEventHistoryData, WorkflowExecutionsListWorkflowExecutionEventHistoryResponse, WorkflowExecutionsCancelWorkflowExecutionData, WorkflowExecutionsCancelWorkflowExecutionResponse, WorkflowExecutionsTerminateWorkflowExecutionData, WorkflowExecutionsTerminateWorkflowExecutionResponse, ActionsListActionsData, ActionsListActionsResponse, ActionsCreateActionData, ActionsCreateActionResponse, ActionsGetActionData, ActionsGetActionResponse, ActionsUpdateActionData, ActionsUpdateActionResponse, ActionsDeleteActionData, ActionsDeleteActionResponse, SecretsSearchSecretsData, SecretsSearchSecretsResponse, SecretsListSecretsData, SecretsListSecretsResponse, SecretsCreateSecretData, SecretsCreateSecretResponse, SecretsGetSecretByNameData, SecretsGetSecretByNameResponse, SecretsUpdateSecretByIdData, SecretsUpdateSecretByIdResponse, SecretsDeleteSecretByIdData, SecretsDeleteSecretByIdResponse, SchedulesListSchedulesData, SchedulesListSchedulesResponse, SchedulesCreateScheduleData, SchedulesCreateScheduleResponse, SchedulesGetScheduleData, SchedulesGetScheduleResponse, SchedulesUpdateScheduleData, SchedulesUpdateScheduleResponse, SchedulesDeleteScheduleData, SchedulesDeleteScheduleResponse, SchedulesSearchSchedulesData, SchedulesSearchSchedulesResponse, UsersSearchUserData, UsersSearchUserResponse, OrganizationListOrgMembersResponse, OrganizationDeleteOrgMemberData, OrganizationDeleteOrgMemberResponse, OrganizationUpdateOrgMemberData, OrganizationUpdateOrgMemberResponse, OrganizationListSessionsResponse, OrganizationDeleteSessionData, OrganizationDeleteSessionResponse, EditorListFunctionsData, EditorListFunctionsResponse, EditorListActionsData, EditorListActionsResponse, RegistryRepositoriesSyncRegistryRepositoriesData, RegistryRepositoriesSyncRegistryRepositoriesResponse, RegistryRepositoriesListRegistryRepositoriesResponse, RegistryRepositoriesCreateRegistryRepositoryData, RegistryRepositoriesCreateRegistryRepositoryResponse, RegistryRepositoriesGetRegistryRepositoryData, RegistryRepositoriesGetRegistryRepositoryResponse, RegistryRepositoriesUpdateRegistryRepositoryData, RegistryRepositoriesUpdateRegistryRepositoryResponse, RegistryRepositoriesDeleteRegistryRepositoryData, RegistryRepositoriesDeleteRegistryRepositoryResponse, RegistryActionsListRegistryActionsResponse, RegistryActionsCreateRegistryActionData, RegistryActionsCreateRegistryActionResponse, RegistryActionsGetRegistryActionData, RegistryActionsGetRegistryActionResponse, RegistryActionsUpdateRegistryActionData, RegistryActionsUpdateRegistryActionResponse, RegistryActionsDeleteRegistryActionData, RegistryActionsDeleteRegistryActionResponse, UsersUsersCurrentUserResponse, UsersUsersPatchCurrentUserData, UsersUsersPatchCurrentUserResponse, UsersUsersUserData, UsersUsersUserResponse, UsersUsersPatchUserData, UsersUsersPatchUserResponse, UsersUsersDeleteUserData, UsersUsersDeleteUserResponse, AuthAuthDatabaseLoginData, AuthAuthDatabaseLoginResponse, AuthAuthDatabaseLogoutResponse, AuthRegisterRegisterData, AuthRegisterRegisterResponse, AuthResetForgotPasswordData, AuthResetForgotPasswordResponse, AuthResetResetPasswordData, AuthResetResetPasswordResponse, AuthVerifyRequestTokenData, AuthVerifyRequestTokenResponse, AuthVerifyVerifyData, AuthVerifyVerifyResponse, AuthOauthGoogleDatabaseAuthorizeData, AuthOauthGoogleDatabaseAuthorizeResponse, AuthOauthGoogleDatabaseCallbackData, AuthOauthGoogleDatabaseCallbackResponse, AuthSamlDatabaseLoginResponse, AuthSsoAcsData, AuthSsoAcsResponse, PublicCheckHealthResponse } from './types.gen'; /** * Incoming Webhook @@ -1110,6 +1110,121 @@ export const usersSearchUser = (data: UsersSearchUserData = {}): CancelablePromi } }); }; +/** + * List Org Members + * @returns OrgMemberRead Successful Response + * @throws ApiError + */ +export const organizationListOrgMembers = (): CancelablePromise => { return __request(OpenAPI, { + method: 'GET', + url: '/organization/members' +}); }; + +/** + * Delete Org Member + * @param data The data for the request. + * @param data.userId + * @returns void Successful Response + * @throws ApiError + */ +export const organizationDeleteOrgMember = (data: OrganizationDeleteOrgMemberData): CancelablePromise => { return __request(OpenAPI, { + method: 'DELETE', + url: '/organization/members/{user_id}', + path: { + user_id: data.userId + }, + errors: { + 422: 'Validation Error' + } +}); }; + +/** + * Update Org Member + * @param data The data for the request. + * @param data.userId + * @param data.requestBody + * @returns OrgMemberRead Successful Response + * @throws ApiError + */ +export const organizationUpdateOrgMember = (data: OrganizationUpdateOrgMemberData): CancelablePromise => { return __request(OpenAPI, { + method: 'PATCH', + url: '/organization/members/{user_id}', + path: { + user_id: data.userId + }, + body: data.requestBody, + mediaType: 'application/json', + errors: { + 422: 'Validation Error' + } +}); }; + +/** + * List Sessions + * @returns SessionRead Successful Response + * @throws ApiError + */ +export const organizationListSessions = (): CancelablePromise => { return __request(OpenAPI, { + method: 'GET', + url: '/organization/sessions' +}); }; + +/** + * Delete Session + * @param data The data for the request. + * @param data.sessionId + * @returns void Successful Response + * @throws ApiError + */ +export const organizationDeleteSession = (data: OrganizationDeleteSessionData): CancelablePromise => { return __request(OpenAPI, { + method: 'DELETE', + url: '/organization/sessions/{session_id}', + path: { + session_id: data.sessionId + }, + errors: { + 422: 'Validation Error' + } +}); }; + +/** + * List Functions + * @param data The data for the request. + * @param data.workspaceId + * @returns EditorFunctionRead Successful Response + * @throws ApiError + */ +export const editorListFunctions = (data: EditorListFunctionsData): CancelablePromise => { return __request(OpenAPI, { + method: 'GET', + url: '/editor/functions', + query: { + workspace_id: data.workspaceId + }, + errors: { + 422: 'Validation Error' + } +}); }; + +/** + * List Actions + * @param data The data for the request. + * @param data.workflowId + * @param data.workspaceId + * @returns EditorActionRead Successful Response + * @throws ApiError + */ +export const editorListActions = (data: EditorListActionsData): CancelablePromise => { return __request(OpenAPI, { + method: 'GET', + url: '/editor/actions', + query: { + workflow_id: data.workflowId, + workspace_id: data.workspaceId + }, + errors: { + 422: 'Validation Error' + } +}); }; + /** * Sync Registry Repositories * Load actions from all registry repositories. @@ -1307,165 +1422,6 @@ export const registryActionsDeleteRegistryAction = (data: RegistryActionsDeleteR } }); }; -/** - * Run Registry Action - * Execute a registry action. - * @param data The data for the request. - * @param data.actionName - * @param data.requestBody - * @returns unknown Successful Response - * @throws ApiError - */ -export const registryActionsRunRegistryAction = (data: RegistryActionsRunRegistryActionData): CancelablePromise => { return __request(OpenAPI, { - method: 'POST', - url: '/registry/actions/{action_name}/execute', - path: { - action_name: data.actionName - }, - body: data.requestBody, - mediaType: 'application/json', - errors: { - 422: 'Validation Error' - } -}); }; - -/** - * Validate Registry Action - * Validate a registry action. - * @param data The data for the request. - * @param data.actionName - * @param data.requestBody - * @returns RegistryActionValidateResponse Successful Response - * @throws ApiError - */ -export const registryActionsValidateRegistryAction = (data: RegistryActionsValidateRegistryActionData): CancelablePromise => { return __request(OpenAPI, { - method: 'POST', - url: '/registry/actions/{action_name}/validate', - path: { - action_name: data.actionName - }, - body: data.requestBody, - mediaType: 'application/json', - errors: { - 422: 'Validation Error' - } -}); }; - -/** - * List Org Members - * @returns OrgMemberRead Successful Response - * @throws ApiError - */ -export const organizationListOrgMembers = (): CancelablePromise => { return __request(OpenAPI, { - method: 'GET', - url: '/organization/members' -}); }; - -/** - * Delete Org Member - * @param data The data for the request. - * @param data.userId - * @returns void Successful Response - * @throws ApiError - */ -export const organizationDeleteOrgMember = (data: OrganizationDeleteOrgMemberData): CancelablePromise => { return __request(OpenAPI, { - method: 'DELETE', - url: '/organization/members/{user_id}', - path: { - user_id: data.userId - }, - errors: { - 422: 'Validation Error' - } -}); }; - -/** - * Update Org Member - * @param data The data for the request. - * @param data.userId - * @param data.requestBody - * @returns OrgMemberRead Successful Response - * @throws ApiError - */ -export const organizationUpdateOrgMember = (data: OrganizationUpdateOrgMemberData): CancelablePromise => { return __request(OpenAPI, { - method: 'PATCH', - url: '/organization/members/{user_id}', - path: { - user_id: data.userId - }, - body: data.requestBody, - mediaType: 'application/json', - errors: { - 422: 'Validation Error' - } -}); }; - -/** - * List Sessions - * @returns SessionRead Successful Response - * @throws ApiError - */ -export const organizationListSessions = (): CancelablePromise => { return __request(OpenAPI, { - method: 'GET', - url: '/organization/sessions' -}); }; - -/** - * Delete Session - * @param data The data for the request. - * @param data.sessionId - * @returns void Successful Response - * @throws ApiError - */ -export const organizationDeleteSession = (data: OrganizationDeleteSessionData): CancelablePromise => { return __request(OpenAPI, { - method: 'DELETE', - url: '/organization/sessions/{session_id}', - path: { - session_id: data.sessionId - }, - errors: { - 422: 'Validation Error' - } -}); }; - -/** - * List Functions - * @param data The data for the request. - * @param data.workspaceId - * @returns EditorFunctionRead Successful Response - * @throws ApiError - */ -export const editorListFunctions = (data: EditorListFunctionsData): CancelablePromise => { return __request(OpenAPI, { - method: 'GET', - url: '/editor/functions', - query: { - workspace_id: data.workspaceId - }, - errors: { - 422: 'Validation Error' - } -}); }; - -/** - * List Actions - * @param data The data for the request. - * @param data.workflowId - * @param data.workspaceId - * @returns EditorActionRead Successful Response - * @throws ApiError - */ -export const editorListActions = (data: EditorListActionsData): CancelablePromise => { return __request(OpenAPI, { - method: 'GET', - url: '/editor/actions', - query: { - workflow_id: data.workflowId, - workspace_id: data.workspaceId - }, - errors: { - 422: 'Validation Error' - } -}); }; - /** * Users:Current User * @returns UserRead Successful Response diff --git a/frontend/src/client/types.gen.ts b/frontend/src/client/types.gen.ts index c1a6492e2..27186f3f0 100644 --- a/frontend/src/client/types.gen.ts +++ b/frontend/src/client/types.gen.ts @@ -56,53 +56,7 @@ export type ActionRetryPolicy = { timeout?: number; }; -export type ActionStatement_Input = { - /** - * The action ID. If this is populated means there is a corresponding actionin the database `Action` table. - */ - id?: string | null; - /** - * Unique reference for the task - */ - ref: string; - description?: string; - /** - * Action type. Equivalent to the UDF key. - */ - action: string; - /** - * Arguments for the action - */ - args?: { - [key: string]: unknown; - }; - /** - * Task dependencies - */ - depends_on?: Array<(string)>; - /** - * Condition to run the task - */ - run_if?: string | null; - /** - * Iterate over a list of items and run the task for each item. - */ - for_each?: string | Array<(string)> | null; - /** - * Retry policy for the action. - */ - retry_policy?: ActionRetryPolicy; - /** - * Delay before starting the action in seconds. - */ - start_delay?: number; - /** - * The strategy to use when joining on this task. By default, all branches must complete successfully before the join task can complete. - */ - join_strategy?: JoinStrategy; -}; - -export type ActionStatement_Output = { +export type ActionStatement = { /** * Unique reference for the task */ @@ -268,6 +222,9 @@ export type DSLConfig_Output = { timeout?: number; }; +/** + * DSL Context. Contains all the context needed to execute a DSL workflow. + */ export type DSLContext = { INPUTS?: { [key: string]: unknown; @@ -277,6 +234,9 @@ export type DSLContext = { }; TRIGGER?: JsonValue; ENV?: DSLEnvironment; + SECRETS?: { + [key: string]: unknown; + }; }; export type DSLEntrypoint = { @@ -320,7 +280,7 @@ export type DSLInput = { title: string; description: string; entrypoint: DSLEntrypoint; - actions: Array; + actions: Array; config?: DSLConfig_Output; triggers?: Array; /** @@ -400,7 +360,7 @@ export type EventGroup = { action_ref: string; action_title: string; action_description: string; - action_input: RunActionInput_Output | DSLRunArgs | GetWorkflowDefinitionActivityInputs; + action_input: RunActionInput | DSLRunArgs | GetWorkflowDefinitionActivityInputs; action_result?: unknown | null; current_attempt?: number | null; retry_policy?: ActionRetryPolicy; @@ -440,7 +400,7 @@ export type GetWorkflowDefinitionActivityInputs = { role: Role; workflow_id: string; version?: number | null; - task?: ActionStatement_Output | null; + task?: ActionStatement | null; }; export type HTTPValidationError = { @@ -657,12 +617,6 @@ export type RegistryActionUpdate = { options?: RegistryActionOptions | null; }; -export type RegistryActionValidate = { - args: { - [key: string]: unknown; - }; -}; - export type RegistryActionValidateResponse = { ok: boolean; message: string; @@ -729,27 +683,18 @@ export type Role = { workspace_id?: string | null; user_id?: string | null; access_level?: AccessLevel; - service_id: 'tracecat-runner' | 'tracecat-api' | 'tracecat-cli' | 'tracecat-schedule-runner' | 'tracecat-service'; + service_id: 'tracecat-runner' | 'tracecat-api' | 'tracecat-cli' | 'tracecat-schedule-runner' | 'tracecat-service' | 'tracecat-executor'; }; export type type2 = 'user' | 'service'; -export type service_id = 'tracecat-runner' | 'tracecat-api' | 'tracecat-cli' | 'tracecat-schedule-runner' | 'tracecat-service'; +export type service_id = 'tracecat-runner' | 'tracecat-api' | 'tracecat-cli' | 'tracecat-schedule-runner' | 'tracecat-service' | 'tracecat-executor'; /** * This object contains all the information needed to execute an action. */ -export type RunActionInput_Input = { - task: ActionStatement_Input; - exec_context: DSLContext; - run_context: RunContext; -}; - -/** - * This object contains all the information needed to execute an action. - */ -export type RunActionInput_Output = { - task: ActionStatement_Output; +export type RunActionInput = { + task: ActionStatement; exec_context: DSLContext; run_context: RunContext; }; @@ -1626,6 +1571,42 @@ export type UsersSearchUserData = { export type UsersSearchUserResponse = UserRead; +export type OrganizationListOrgMembersResponse = Array; + +export type OrganizationDeleteOrgMemberData = { + userId: string; +}; + +export type OrganizationDeleteOrgMemberResponse = void; + +export type OrganizationUpdateOrgMemberData = { + requestBody: UserUpdate; + userId: string; +}; + +export type OrganizationUpdateOrgMemberResponse = OrgMemberRead; + +export type OrganizationListSessionsResponse = Array; + +export type OrganizationDeleteSessionData = { + sessionId: string; +}; + +export type OrganizationDeleteSessionResponse = void; + +export type EditorListFunctionsData = { + workspaceId: string; +}; + +export type EditorListFunctionsResponse = Array; + +export type EditorListActionsData = { + workflowId: string; + workspaceId: string; +}; + +export type EditorListActionsResponse = Array; + export type RegistryRepositoriesSyncRegistryRepositoriesData = { /** * Origins to sync. If no origins provided, all repositories will be synced. @@ -1689,56 +1670,6 @@ export type RegistryActionsDeleteRegistryActionData = { export type RegistryActionsDeleteRegistryActionResponse = void; -export type RegistryActionsRunRegistryActionData = { - actionName: string; - requestBody: RunActionInput_Input; -}; - -export type RegistryActionsRunRegistryActionResponse = unknown; - -export type RegistryActionsValidateRegistryActionData = { - actionName: string; - requestBody: RegistryActionValidate; -}; - -export type RegistryActionsValidateRegistryActionResponse = RegistryActionValidateResponse; - -export type OrganizationListOrgMembersResponse = Array; - -export type OrganizationDeleteOrgMemberData = { - userId: string; -}; - -export type OrganizationDeleteOrgMemberResponse = void; - -export type OrganizationUpdateOrgMemberData = { - requestBody: UserUpdate; - userId: string; -}; - -export type OrganizationUpdateOrgMemberResponse = OrgMemberRead; - -export type OrganizationListSessionsResponse = Array; - -export type OrganizationDeleteSessionData = { - sessionId: string; -}; - -export type OrganizationDeleteSessionResponse = void; - -export type EditorListFunctionsData = { - workspaceId: string; -}; - -export type EditorListFunctionsResponse = Array; - -export type EditorListActionsData = { - workflowId: string; - workspaceId: string; -}; - -export type EditorListActionsResponse = Array; - export type UsersUsersCurrentUserResponse = UserRead; export type UsersUsersPatchCurrentUserData = { @@ -2507,75 +2438,57 @@ export type $OpenApiTs = { }; }; }; - '/registry/repos/sync': { - post: { - req: RegistryRepositoriesSyncRegistryRepositoriesData; - res: { - /** - * Successful Response - */ - 204: void; - /** - * Validation Error - */ - 422: HTTPValidationError; - }; - }; - }; - '/registry/repos': { + '/organization/members': { get: { res: { /** * Successful Response */ - 200: Array; + 200: Array; }; }; - post: { - req: RegistryRepositoriesCreateRegistryRepositoryData; + }; + '/organization/members/{user_id}': { + delete: { + req: OrganizationDeleteOrgMemberData; res: { /** * Successful Response */ - 201: RegistryRepositoryRead; + 204: void; /** * Validation Error */ 422: HTTPValidationError; }; }; - }; - '/registry/repos/{origin}': { - get: { - req: RegistryRepositoriesGetRegistryRepositoryData; + patch: { + req: OrganizationUpdateOrgMemberData; res: { /** * Successful Response */ - 200: RegistryRepositoryRead; + 200: OrgMemberRead; /** * Validation Error */ 422: HTTPValidationError; }; }; - patch: { - req: RegistryRepositoriesUpdateRegistryRepositoryData; + }; + '/organization/sessions': { + get: { res: { /** * Successful Response */ - 200: RegistryRepositoryRead; - /** - * Validation Error - */ - 422: HTTPValidationError; + 200: Array; }; }; }; - '/registry/repos/{id}': { + '/organization/sessions/{session_id}': { delete: { - req: RegistryRepositoriesDeleteRegistryRepositoryData; + req: OrganizationDeleteSessionData; res: { /** * Successful Response @@ -2588,22 +2501,14 @@ export type $OpenApiTs = { }; }; }; - '/registry/actions': { + '/editor/functions': { get: { + req: EditorListFunctionsData; res: { /** * Successful Response */ - 200: Array; - }; - }; - post: { - req: RegistryActionsCreateRegistryActionData; - res: { - /** - * Successful Response - */ - 201: RegistryActionRead; + 200: Array; /** * Validation Error */ @@ -2611,22 +2516,24 @@ export type $OpenApiTs = { }; }; }; - '/registry/actions/{action_name}': { + '/editor/actions': { get: { - req: RegistryActionsGetRegistryActionData; + req: EditorListActionsData; res: { /** * Successful Response */ - 200: RegistryActionRead; + 200: Array; /** * Validation Error */ 422: HTTPValidationError; }; }; - patch: { - req: RegistryActionsUpdateRegistryActionData; + }; + '/registry/repos/sync': { + post: { + req: RegistryRepositoriesSyncRegistryRepositoriesData; res: { /** * Successful Response @@ -2638,28 +2545,23 @@ export type $OpenApiTs = { 422: HTTPValidationError; }; }; - delete: { - req: RegistryActionsDeleteRegistryActionData; + }; + '/registry/repos': { + get: { res: { /** * Successful Response */ - 204: void; - /** - * Validation Error - */ - 422: HTTPValidationError; + 200: Array; }; }; - }; - '/registry/actions/{action_name}/execute': { post: { - req: RegistryActionsRunRegistryActionData; + req: RegistryRepositoriesCreateRegistryRepositoryData; res: { /** * Successful Response */ - 200: unknown; + 201: RegistryRepositoryRead; /** * Validation Error */ @@ -2667,34 +2569,37 @@ export type $OpenApiTs = { }; }; }; - '/registry/actions/{action_name}/validate': { - post: { - req: RegistryActionsValidateRegistryActionData; + '/registry/repos/{origin}': { + get: { + req: RegistryRepositoriesGetRegistryRepositoryData; res: { /** * Successful Response */ - 200: RegistryActionValidateResponse; + 200: RegistryRepositoryRead; /** * Validation Error */ 422: HTTPValidationError; }; }; - }; - '/organization/members': { - get: { + patch: { + req: RegistryRepositoriesUpdateRegistryRepositoryData; res: { /** * Successful Response */ - 200: Array; + 200: RegistryRepositoryRead; + /** + * Validation Error + */ + 422: HTTPValidationError; }; }; }; - '/organization/members/{user_id}': { + '/registry/repos/{id}': { delete: { - req: OrganizationDeleteOrgMemberData; + req: RegistryRepositoriesDeleteRegistryRepositoryData; res: { /** * Successful Response @@ -2706,68 +2611,64 @@ export type $OpenApiTs = { 422: HTTPValidationError; }; }; - patch: { - req: OrganizationUpdateOrgMemberData; + }; + '/registry/actions': { + get: { res: { /** * Successful Response */ - 200: OrgMemberRead; - /** - * Validation Error - */ - 422: HTTPValidationError; + 200: Array; }; }; - }; - '/organization/sessions': { - get: { + post: { + req: RegistryActionsCreateRegistryActionData; res: { /** * Successful Response */ - 200: Array; + 201: RegistryActionRead; + /** + * Validation Error + */ + 422: HTTPValidationError; }; }; }; - '/organization/sessions/{session_id}': { - delete: { - req: OrganizationDeleteSessionData; + '/registry/actions/{action_name}': { + get: { + req: RegistryActionsGetRegistryActionData; res: { /** * Successful Response */ - 204: void; + 200: RegistryActionRead; /** * Validation Error */ 422: HTTPValidationError; }; }; - }; - '/editor/functions': { - get: { - req: EditorListFunctionsData; + patch: { + req: RegistryActionsUpdateRegistryActionData; res: { /** * Successful Response */ - 200: Array; + 204: void; /** * Validation Error */ 422: HTTPValidationError; }; }; - }; - '/editor/actions': { - get: { - req: EditorListActionsData; + delete: { + req: RegistryActionsDeleteRegistryActionData; res: { /** * Successful Response */ - 200: Array; + 204: void; /** * Validation Error */ diff --git a/frontend/src/components/executions/event-details.tsx b/frontend/src/components/executions/event-details.tsx index 5bd323b35..7b3263986 100644 --- a/frontend/src/components/executions/event-details.tsx +++ b/frontend/src/components/executions/event-details.tsx @@ -1,9 +1,5 @@ import React from "react" -import { - DSLRunArgs, - EventHistoryResponse, - RunActionInput_Output, -} from "@/client" +import { DSLRunArgs, EventHistoryResponse, RunActionInput } from "@/client" import JsonView from "react18-json-view" import { @@ -139,7 +135,7 @@ export function WorkflowExecutionEventDetailView({
- +
@@ -412,7 +408,7 @@ function ActionEventGeneralInfo({ task: { depends_on, run_if, for_each }, }, }: { - input: RunActionInput_Output + input: RunActionInput }) { return (
@@ -467,12 +463,12 @@ function ActionEventGeneralInfo({ function isRunActionInput_Output( actionInput: unknown -): actionInput is RunActionInput_Output { +): actionInput is RunActionInput { return ( typeof actionInput === "object" && actionInput !== null && "task" in actionInput && - typeof (actionInput as RunActionInput_Output).task === "object" + typeof (actionInput as RunActionInput).task === "object" ) } diff --git a/tests/conftest.py b/tests/conftest.py index ccb4072c3..b0bcf711c 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -59,8 +59,9 @@ def env_sandbox(monkeysession: pytest.MonkeyPatch): "TRACECAT__REMOTE_REPOSITORY_URL", "git+ssh://git@github.com/TracecatHQ/udfs.git", ) + # Need this for local unit tests monkeysession.setattr( - config, "TRACECAT__REGISTRY_URL", "http://localhost/api/registry" + config, "TRACECAT__EXECUTOR_URL", "http://localhost/api/executor" ) monkeysession.setenv( @@ -69,9 +70,9 @@ def env_sandbox(monkeysession: pytest.MonkeyPatch): ) # monkeysession.setenv("TRACECAT__DB_ENCRYPTION_KEY", Fernet.generate_key().decode()) monkeysession.setenv("TRACECAT__API_URL", "http://api:8000") - monkeysession.setenv("TRACECAT__REGISTRY_URL", "http://registry:8000") + # Needed for local unit tests + monkeysession.setenv("TRACECAT__EXECUTOR_URL", "http://executor:8000") monkeysession.setenv("TRACECAT__PUBLIC_API_URL", "http://localhost/api") - monkeysession.setenv("TRACECAT__PUBLIC_RUNNER_URL", "http://localhost:8001") monkeysession.setenv("TRACECAT__SERVICE_KEY", os.environ["TRACECAT__SERVICE_KEY"]) monkeysession.setenv("TRACECAT__SIGNING_SECRET", "test-signing-secret") # When launching the worker directly in a test, use localhost diff --git a/tests/unit/test_workflows.py b/tests/unit/test_workflows.py index 9f7a9eb46..1f35e4ab9 100644 --- a/tests/unit/test_workflows.py +++ b/tests/unit/test_workflows.py @@ -1601,7 +1601,7 @@ async def test_pull_based_workflow_fetches_latest_version(temporal_client, test_ "------------------------------\n" "File: /app/tracecat/registry/executor.py\n" "Function: run_action_in_pool\n" - "Line: 83" + "Line: 200" ), "type": "RegistryActionError", "expr_context": "ACTIONS", diff --git a/tracecat/api/app.py b/tracecat/api/app.py index 2832b06b7..0c678db62 100644 --- a/tracecat/api/app.py +++ b/tracecat/api/app.py @@ -12,6 +12,7 @@ from tracecat.api.common import ( custom_generate_unique_id, generic_exception_handler, + setup_registry, tracecat_exception_handler, ) from tracecat.auth.constants import AuthType @@ -29,6 +30,8 @@ from tracecat.logger import logger from tracecat.middleware import RequestLoggingMiddleware from tracecat.organization.router import router as org_router +from tracecat.registry.actions.router import router as registry_actions_router +from tracecat.registry.repositories.router import router as registry_repos_router from tracecat.secrets.router import router as secrets_router from tracecat.types.auth import AccessLevel, Role from tracecat.types.exceptions import TracecatException @@ -50,6 +53,7 @@ async def lifespan(app: FastAPI): ) async with get_async_session_context_manager() as session: await setup_defaults(session, admin_role) + await setup_registry(session, admin_role) yield @@ -139,6 +143,8 @@ def create_app(**kwargs) -> FastAPI: app.include_router(users_router) app.include_router(org_router) app.include_router(editor_router) + app.include_router(registry_repos_router) + app.include_router(registry_actions_router) app.include_router( fastapi_users.get_users_router(UserRead, UserUpdate), prefix="/users", diff --git a/tracecat/api/registry.py b/tracecat/api/executor.py similarity index 64% rename from tracecat/api/registry.py rename to tracecat/api/executor.py index 796955803..8664d2b61 100644 --- a/tracecat/api/registry.py +++ b/tracecat/api/executor.py @@ -9,31 +9,19 @@ custom_generate_unique_id, generic_exception_handler, setup_oss_models, - setup_registry, tracecat_exception_handler, ) -from tracecat.db.engine import get_async_session_context_manager from tracecat.logger import logger from tracecat.middleware import RequestLoggingMiddleware -from tracecat.registry.actions.router import router as registry_actions_router -from tracecat.registry.executor import get_executor -from tracecat.registry.repositories.router import router as registry_repos_router -from tracecat.types.auth import AccessLevel, Role +from tracecat.registry.executor import get_executor, router from tracecat.types.exceptions import TracecatException @asynccontextmanager async def lifespan(app: FastAPI): - admin_role = Role( - type="service", - access_level=AccessLevel.ADMIN, - service_id="tracecat-registry", - ) - async with get_async_session_context_manager() as session: - await setup_registry(session, admin_role) await setup_oss_models() + executor = get_executor() try: - executor = get_executor() yield finally: executor.shutdown() @@ -45,20 +33,19 @@ def create_app(**kwargs) -> FastAPI: else: allow_origins = ["*"] app = FastAPI( - title="Tracecat Registry", - description="Registry action executor.", - summary="Tracecat Registry", + title="Tracecat Executor", + description="Action executor for Tracecat.", + summary="Tracecat Executor", lifespan=lifespan, default_response_class=ORJSONResponse, generate_unique_id_function=custom_generate_unique_id, - root_path="/api/registry", + root_path="/api/executor", **kwargs, ) app.logger = logger # type: ignore # Routers - app.include_router(registry_repos_router) - app.include_router(registry_actions_router) + app.include_router(router) # Exception handlers app.add_exception_handler(Exception, generic_exception_handler) @@ -75,7 +62,7 @@ def create_app(**kwargs) -> FastAPI: ) logger.info( - "Registry service started", + "Executor service started", env=config.TRACECAT__APP_ENV, origins=allow_origins, auth_types=config.TRACECAT__AUTH_TYPES, @@ -89,4 +76,4 @@ def create_app(**kwargs) -> FastAPI: @app.get("/", include_in_schema=False) def root() -> dict[str, str]: - return {"message": "Hello world. I am the registry."} + return {"message": "Hello world. I am the executor."} diff --git a/tracecat/config.py b/tracecat/config.py index 07507ef63..c23a1f3b6 100644 --- a/tracecat/config.py +++ b/tracecat/config.py @@ -4,23 +4,12 @@ from tracecat.auth.constants import AuthType -# === Actions Config === # -HTTP_MAX_RETRIES = 10 -LLM_MAX_RETRIES = 3 - # === Internal Services === # -TRACECAT__SCHEDULE_INTERVAL_SECONDS = os.environ.get( - "TRACECAT__SCHEDULE_INTERVAL_SECONDS", 60 -) -TRACECAT__SCHEDULE_MAX_CONNECTIONS = 6 TRACECAT__APP_ENV: Literal["development", "staging", "production"] = os.environ.get( "TRACECAT__APP_ENV", "development" ) # type: ignore TRACECAT__API_URL = os.environ.get("TRACECAT__API_URL", "http://localhost:8000") TRACECAT__API_ROOT_PATH = os.environ.get("TRACECAT__API_ROOT_PATH", "/api") -TRACECAT__PUBLIC_RUNNER_URL = os.environ.get( - "TRACECAT__PUBLIC_RUNNER_URL", "http://localhost/api" -) TRACECAT__PUBLIC_API_URL = os.environ.get( "TRACECAT__PUBLIC_API_URL", "http://localhost/api" ) @@ -32,8 +21,8 @@ "TRACECAT__DB_URI", "postgresql+psycopg://postgres:postgres@postgres_db:5432/postgres", ) -TRACECAT__REGISTRY_URL = os.environ.get( - "TRACECAT__REGISTRY_URL", "http://registry:8000" +TRACECAT__EXECUTOR_URL = os.environ.get( + "TRACECAT__EXECUTOR_URL", "http://executor:8000" ) TRACECAT__DB_NAME = os.environ.get("TRACECAT__DB_NAME") diff --git a/tracecat/db/schemas.py b/tracecat/db/schemas.py index 721adff0f..806f64362 100644 --- a/tracecat/db/schemas.py +++ b/tracecat/db/schemas.py @@ -342,7 +342,7 @@ def secret(self) -> str: @computed_field @property def url(self) -> str: - return f"{config.TRACECAT__PUBLIC_RUNNER_URL}/webhooks/{self.workflow_id}/{self.secret}" + return f"{config.TRACECAT__PUBLIC_API_URL}/webhooks/{self.workflow_id}/{self.secret}" class Schedule(Resource, table=True): diff --git a/tracecat/identifiers/__init__.py b/tracecat/identifiers/__init__.py index d883a4778..19e274f7f 100644 --- a/tracecat/identifiers/__init__.py +++ b/tracecat/identifiers/__init__.py @@ -71,7 +71,7 @@ "tracecat-cli", "tracecat-schedule-runner", "tracecat-service", - "tracecat-registry", + "tracecat-executor", ] __all__ = [ diff --git a/tracecat/registry/actions/router.py b/tracecat/registry/actions/router.py index 8af7da1be..dc7f52484 100644 --- a/tracecat/registry/actions/router.py +++ b/tracecat/registry/actions/router.py @@ -1,29 +1,19 @@ -import traceback -from typing import Any - from fastapi import APIRouter, HTTPException, status from sqlalchemy.exc import IntegrityError from tracecat.auth.credentials import RoleACL from tracecat.concurrency import GatheringTaskGroup -from tracecat.contexts import ctx_logger, ctx_role from tracecat.db.dependencies import AsyncDBSession -from tracecat.dsl.models import RunActionInput from tracecat.logger import logger -from tracecat.registry import executor from tracecat.registry.actions.models import ( RegistryActionCreate, - RegistryActionErrorInfo, RegistryActionRead, RegistryActionUpdate, - RegistryActionValidate, - RegistryActionValidateResponse, ) from tracecat.registry.actions.service import RegistryActionsService from tracecat.registry.constants import DEFAULT_REGISTRY_ORIGIN, REGISTRY_ACTIONS_PATH from tracecat.types.auth import AccessLevel, Role from tracecat.types.exceptions import RegistryError -from tracecat.validation.service import validate_registry_action_args router = APIRouter(prefix=REGISTRY_ACTIONS_PATH, tags=["registry-actions"]) @@ -154,82 +144,3 @@ async def delete_registry_action( ) # Delete the action as it's not a base action await service.delete_action(action) - - -# Registry Action Controls - - -@router.post("/{action_name}/execute") -async def run_registry_action( - *, - role: Role = RoleACL( - allow_user=False, # XXX(authz): Users cannot execute actions - allow_service=True, # Only services can execute actions - require_workspace="no", - ), - action_name: str, - action_input: RunActionInput, -) -> Any: - """Execute a registry action.""" - ref = action_input.task.ref - ctx_role.set(role) - act_logger = logger.bind(role=role, action_name=action_name, ref=ref) - ctx_logger.set(act_logger) - - act_logger.info("Starting action") - try: - return await executor.run_action_in_pool(input=action_input) - except Exception as e: - # Get the traceback info - tb = traceback.extract_tb(e.__traceback__)[-1] # Get the last frame - error_detail = RegistryActionErrorInfo( - action_name=action_name, - type=e.__class__.__name__, - message=str(e), - filename=tb.filename, - function=tb.name, - lineno=tb.lineno, - ) - act_logger.error( - "Error running action", - action_name=action_name, - type=error_detail.type, - message=error_detail.message, - filename=error_detail.filename, - function=error_detail.function, - lineno=error_detail.lineno, - ) - raise HTTPException( - status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, - detail=error_detail.model_dump(mode="json"), - ) from e - - -@router.post("/{action_name}/validate") -async def validate_registry_action( - *, - role: Role = RoleACL( - allow_user=False, # XXX(authz): Users cannot validate actions - allow_service=True, # Only services can validate actions - require_workspace="no", - ), - session: AsyncDBSession, - action_name: str, - params: RegistryActionValidate, -) -> RegistryActionValidateResponse: - """Validate a registry action.""" - try: - result = await validate_registry_action_args( - session=session, action_name=action_name, args=params.args - ) - - if result.status == "error": - logger.warning( - "Error validating UDF args", message=result.msg, details=result.detail - ) - return RegistryActionValidateResponse.from_validation_result(result) - except KeyError as e: - raise HTTPException( - status_code=status.HTTP_404_NOT_FOUND, - detail=f"Action {action_name!r} not found in registry", - ) from e diff --git a/tracecat/registry/actions/service.py b/tracecat/registry/actions/service.py index 07a7402c0..5648059cc 100644 --- a/tracecat/registry/actions/service.py +++ b/tracecat/registry/actions/service.py @@ -22,6 +22,7 @@ RegistryActionUpdate, model_converters, ) +from tracecat.registry.client import RegistryClient from tracecat.registry.loaders import get_bound_action_impl from tracecat.registry.repository import Repository from tracecat.types.auth import Role @@ -208,9 +209,11 @@ async def sync_actions_from_repository(self, db_repo: RegistryRepository) -> Non - For each repository, we need to reimport the packages to run decorators. (for remote this involves pulling) - Scan the repositories for implementation details/metadata and update the DB """ + # (1) Update the API's view of the repository repo = Repository(origin=db_repo.origin, role=self.role) await repo.load_from_origin() + # (2) Handle DB bookkeeping for the API's view of the repository # Perform diffing here. The expectation for this endpoint is to sync Tracecat's view of # the repository with the remote repository -- meaning any creation/updates/deletions to # actions should be propogated to the db. @@ -275,6 +278,11 @@ async def sync_actions_from_repository(self, db_repo: RegistryRepository) -> Non deleted=n_deleted, ) + # (3) Update the executor's view of the repository + self.logger.info("Syncing executor", origin=db_repo.origin) + client = RegistryClient(role=self.role) + await client.sync_executor(origin=db_repo.origin) + async def load_action_impl(self, action_name: str) -> BoundRegistryAction: """ Load the implementation for a registry action. diff --git a/tracecat/registry/client.py b/tracecat/registry/client.py index 43bc50749..c6fa04039 100644 --- a/tracecat/registry/client.py +++ b/tracecat/registry/client.py @@ -1,11 +1,18 @@ """Use this in worker to execute actions.""" -from collections.abc import Mapping +from collections.abc import AsyncIterator, Mapping +from contextlib import asynccontextmanager from json import JSONDecodeError from typing import Any, cast import httpx import orjson +from tenacity import ( + retry, + retry_if_exception_type, + stop_after_attempt, + wait_exponential, +) from tracecat import config from tracecat.clients import AuthenticatedServiceClient @@ -26,9 +33,11 @@ class _RegistryHTTPClient(AuthenticatedServiceClient): """Async httpx client for the registry service.""" def __init__(self, role: Role | None = None, *args: Any, **kwargs: Any) -> None: - self._registry_base_url = config.TRACECAT__REGISTRY_URL + self._registry_base_url = config.TRACECAT__EXECUTOR_URL super().__init__(role, *args, base_url=self._registry_base_url, **kwargs) - self.params = self.params.add("workspace_id", str(self.role.workspace_id)) + self.params = self.params.add( + "workspace_id", str(self.role.workspace_id) if self.role else None + ) class RegistryClient: @@ -42,6 +51,11 @@ def __init__(self, role: Role | None = None): self.role = role or ctx_role.get() self.logger = logger.bind(service="registry-client", role=self.role) + @asynccontextmanager + async def _client(self) -> AsyncIterator[_RegistryHTTPClient]: + async with _RegistryHTTPClient(self.role) as client: + yield client + """Execution""" async def call_action(self, input: RunActionInput) -> Any: @@ -75,7 +89,6 @@ async def call_action(self, input: RunActionInput) -> Any: action_type = input.task.action content = input.model_dump_json() - workspace_id = str(self.role.workspace_id) if self.role.workspace_id else None logger.debug( f"Calling action {action_type!r} with content", content=content, @@ -83,9 +96,9 @@ async def call_action(self, input: RunActionInput) -> Any: timeout=self._timeout, ) try: - async with _RegistryHTTPClient(self.role) as client: + async with self._client() as client: response = await client.post( - f"{self._actions_endpoint}/{action_type}/execute", + f"/run/{action_type}", # NOTE(perf): Maybe serialize with orjson.dumps instead headers={ "Content-Type": "application/json", @@ -93,7 +106,6 @@ async def call_action(self, input: RunActionInput) -> Any: **self.role.to_headers(), }, content=content, - params={"workspace_id": workspace_id}, timeout=self._timeout, ) response.raise_for_status() @@ -144,10 +156,9 @@ async def validate_action( """Validate an action.""" try: logger.warning("Validating action") - async with _RegistryHTTPClient(self.role) as client: + async with self._client() as client: response = await client.post( - f"{self._actions_endpoint}/{action_name}/validate", - json={"args": args}, + f"/validate/{action_name}", json={"args": args} ) response.raise_for_status() return RegistryActionValidateResponse.model_validate_json(response.content) @@ -164,6 +175,55 @@ async def validate_action( f"Unexpected error while listing registries: {str(e)}" ) from e + """Executor""" + + async def sync_executor(self, origin: str, *, max_attempts: int = 3) -> None: + """Sync the executor from the registry. + + Args: + origin: The origin of the sync request + + Raises: + RegistryError: If the sync fails after all retries + """ + + @retry( + stop=stop_after_attempt(max_attempts), + wait=wait_exponential(multiplier=1, min=4, max=10), + retry=retry_if_exception_type( + ( + httpx.HTTPStatusError, + httpx.RequestError, + httpx.TimeoutException, + httpx.ConnectError, + ) + ), + ) + async def _sync_request() -> None: + try: + async with self._client() as client: + response = await client.post("/sync", json={"origin": origin}) + response.raise_for_status() + except Exception as e: + logger.error("Error syncing executor", error=e) + raise + + try: + logger.info("Syncing executor", origin=origin) + _ = await _sync_request() + except httpx.HTTPStatusError as e: + raise RegistryError( + f"Failed to sync executor: HTTP {e.response.status_code}" + ) from e + except httpx.RequestError as e: + raise RegistryError( + f"Network error while syncing executor: {str(e)}" + ) from e + except Exception as e: + raise RegistryError( + f"Unexpected error while syncing executor: {str(e)}" + ) from e + """Registry management""" async def list_repositories(self) -> list[str]: diff --git a/tracecat/registry/constants.py b/tracecat/registry/constants.py index 25167d959..311598999 100644 --- a/tracecat/registry/constants.py +++ b/tracecat/registry/constants.py @@ -3,8 +3,8 @@ CUSTOM_REPOSITORY_ORIGIN = "custom" GITHUB_SSH_KEY_SECRET_NAME = "github-ssh-key" -REGISTRY_REPOS_PATH: str = "/repos" +REGISTRY_REPOS_PATH: str = "/registry/repos" """Base path for repository-related endpoints""" -REGISTRY_ACTIONS_PATH: str = "/actions" +REGISTRY_ACTIONS_PATH: str = "/registry/actions" """Base path for action-related endpoints""" diff --git a/tracecat/registry/executor.py b/tracecat/registry/executor.py index ffd7de38a..6eb4b8ec8 100644 --- a/tracecat/registry/executor.py +++ b/tracecat/registry/executor.py @@ -6,16 +6,21 @@ from __future__ import annotations import asyncio +import traceback from collections.abc import Iterator, Mapping from concurrent.futures import ProcessPoolExecutor from typing import Any, cast import uvloop +from fastapi import APIRouter, HTTPException, status +from pydantic import BaseModel from tracecat import config +from tracecat.auth.credentials import RoleACL from tracecat.auth.sandbox import AuthSandbox from tracecat.concurrency import GatheringTaskGroup from tracecat.contexts import ctx_logger, ctx_role, ctx_run +from tracecat.db.dependencies import AsyncDBSession from tracecat.db.engine import get_async_engine from tracecat.dsl.common import context_locator, create_default_dsl_context from tracecat.dsl.models import ( @@ -33,20 +38,132 @@ from tracecat.expressions.shared import ExprContext from tracecat.logger import logger from tracecat.parse import traverse_leaves -from tracecat.registry.actions.models import ArgsClsT, BoundRegistryAction +from tracecat.registry.actions.models import ( + ArgsClsT, + BoundRegistryAction, + RegistryActionErrorInfo, + RegistryActionValidate, + RegistryActionValidateResponse, +) from tracecat.registry.actions.service import RegistryActionsService +from tracecat.registry.repository import Repository from tracecat.secrets.common import apply_masks_object from tracecat.secrets.constants import DEFAULT_SECRETS_ENVIRONMENT from tracecat.secrets.secrets_manager import env_sandbox from tracecat.types.auth import Role -from tracecat.types.exceptions import TracecatException +from tracecat.types.exceptions import RegistryError, TracecatException +from tracecat.validation.service import validate_registry_action_args """All these methods are used in the registry executor, not on the worker""" -type ArgsT = Mapping[str, Any] +# Registry Action Controls +type ArgsT = Mapping[str, Any] _executor: ProcessPoolExecutor | None = None +router = APIRouter(tags=["executor"]) + + +class ExecutorSyncInput(BaseModel): + origin: str + + +@router.post("/sync") +async def sync_executor( + *, + role: Role = RoleACL( + allow_user=False, # XXX(authz): Users cannot sync the executor + allow_service=True, # Only services can sync the executor + require_workspace="no", + ), + input: ExecutorSyncInput, +) -> None: + """Sync the executor from the registry.""" + repo = Repository(origin=input.origin, role=role) + try: + await repo.load_from_origin() + except RegistryError as e: + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=str(e) + ) from e + + +@router.post("/run/{action_name}") +async def run_action( + *, + role: Role = RoleACL( + allow_user=False, # XXX(authz): Users cannot execute actions + allow_service=True, # Only services can execute actions + require_workspace="no", + ), + action_name: str, + action_input: RunActionInput, +) -> Any: + """Execute a registry action.""" + ref = action_input.task.ref + ctx_role.set(role) + act_logger = logger.bind(role=role, action_name=action_name, ref=ref) + ctx_logger.set(act_logger) + + act_logger.info("Starting action") + try: + return await run_action_in_pool(input=action_input) + except Exception as e: + # Get the traceback info + tb = traceback.extract_tb(e.__traceback__)[-1] # Get the last frame + error_detail = RegistryActionErrorInfo( + action_name=action_name, + type=e.__class__.__name__, + message=str(e), + filename=tb.filename, + function=tb.name, + lineno=tb.lineno, + ) + act_logger.error( + "Error running action", + action_name=action_name, + type=error_detail.type, + message=error_detail.message, + filename=error_detail.filename, + function=error_detail.function, + lineno=error_detail.lineno, + ) + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail=error_detail.model_dump(mode="json"), + ) from e + + +@router.post("/validate/{action_name}") +async def validate_action( + *, + role: Role = RoleACL( + allow_user=False, # XXX(authz): Users cannot validate actions + allow_service=True, # Only services can validate actions + require_workspace="no", + ), + session: AsyncDBSession, + action_name: str, + params: RegistryActionValidate, +) -> RegistryActionValidateResponse: + """Validate a registry action.""" + try: + result = await validate_registry_action_args( + session=session, action_name=action_name, args=params.args + ) + + if result.status == "error": + logger.warning( + "Error validating UDF args", message=result.msg, details=result.detail + ) + return RegistryActionValidateResponse.from_validation_result(result) + except KeyError as e: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail=f"Action {action_name!r} not found in registry", + ) from e + + # We want to be able to serve a looped action # Before we send out tasks to the executor we should inspect the size of the loop # and set the right chunk size for each worker diff --git a/tracecat/workflow/executions/service.py b/tracecat/workflow/executions/service.py index 2cbe025ad..6a71fbb75 100644 --- a/tracecat/workflow/executions/service.py +++ b/tracecat/workflow/executions/service.py @@ -19,6 +19,7 @@ WorkflowHandle, WorkflowHistoryEventFilterType, ) +from temporalio.service import RPCError from tracecat import config from tracecat.contexts import ctx_role @@ -435,6 +436,15 @@ async def _dispatch_workflow( except WorkflowFailureError as e: self.logger.error(str(e), role=self.role, wf_exec_id=wf_exec_id, e=e) raise e + except RPCError as e: + self.logger.error( + f"Temporal service RPC error occurred while executing the workflow: {e}", + role=self.role, + wf_exec_id=wf_exec_id, + e=e, + ) + raise e + except Exception as e: self.logger.exception( "Unexpected workflow error", role=self.role, wf_exec_id=wf_exec_id, e=e From 3315fccff9dc13cce55bd396300e401fd122471b Mon Sep 17 00:00:00 2001 From: Chris Lo <46541035+topher-lo@users.noreply.github.com> Date: Mon, 9 Dec 2024 19:20:26 -0800 Subject: [PATCH 12/13] ci(build): Run ARM docker builds on ARM runners (#594) --- .github/workflows/build-push-images.yml | 22 ++++++++++++++++++---- 1 file changed, 18 insertions(+), 4 deletions(-) diff --git a/.github/workflows/build-push-images.yml b/.github/workflows/build-push-images.yml index 0f346e42e..55e2bc088 100644 --- a/.github/workflows/build-push-images.yml +++ b/.github/workflows/build-push-images.yml @@ -12,7 +12,14 @@ permissions: jobs: push-api-to-ghcr: - runs-on: ubuntu-latest + runs-on: ${{ matrix.runner }} + strategy: + matrix: + include: + - platform: linux/amd64 + runner: ubuntu-latest + - platform: linux/arm64 + runner: ubuntu-arm64-latest steps: - name: Checkout repository uses: actions/checkout@v4 @@ -48,7 +55,7 @@ jobs: with: context: . push: true - platforms: linux/amd64,linux/arm64 + platforms: ${{ matrix.platform }} tags: | ${{ steps.meta.outputs.tags }} ${{ github.event_name == 'push' && startsWith(github.ref, 'refs/tags/v') && 'ghcr.io/tracecathq/tracecat:latest' || '' }} @@ -57,7 +64,14 @@ jobs: cache-to: type=gha,mode=max push-ui-to-ghcr: - runs-on: ubuntu-latest + runs-on: ${{ matrix.runner }} + strategy: + matrix: + include: + - platform: linux/amd64 + runner: ubuntu-latest + - platform: linux/arm64 + runner: ubuntu-arm64-latest steps: - name: Checkout repository uses: actions/checkout@v4 @@ -106,7 +120,7 @@ jobs: NEXT_SERVER_API_URL=${{ env.NEXT_SERVER_API_URL }} NODE_ENV=${{ env.NODE_ENV }} push: true - platforms: linux/amd64,linux/arm64 + platforms: ${{ matrix.platform }} tags: ${{ steps.meta.outputs.tags }} ${{ github.event_name == 'push' && startsWith(github.ref, 'refs/tags/v') && 'ghcr.io/tracecathq/tracecat-ui:latest' || '' }} labels: ${{ steps.meta.outputs.labels }} From fe5350e7f6f5426b123525a7db2535977fc9b810 Mon Sep 17 00:00:00 2001 From: topher-lo <46541035+topher-lo@users.noreply.github.com> Date: Mon, 9 Dec 2024 19:39:51 -0800 Subject: [PATCH 13/13] ci(infra): Executor port command 8002 not 8000 --- deployments/aws/ecs/ecs-executor.tf | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/deployments/aws/ecs/ecs-executor.tf b/deployments/aws/ecs/ecs-executor.tf index 8677a98d7..bc004b973 100644 --- a/deployments/aws/ecs/ecs-executor.tf +++ b/deployments/aws/ecs/ecs-executor.tf @@ -20,7 +20,7 @@ resource "aws_ecs_task_definition" "executor_task_definition" { "--host", "0.0.0.0", "--port", - "8000" + "8002" ] portMappings = [ {