From cc9c7c236db5711adc1dc48d6cde71e9bda44ed1 Mon Sep 17 00:00:00 2001 From: Daryl Lim <5508348+daryllimyt@users.noreply.github.com> Date: Wed, 15 Jan 2025 11:35:36 +0000 Subject: [PATCH] feat(engine): Add and update git and ssh modules (#754) --- frontend/src/client/schemas.gen.ts | 1 + frontend/src/client/types.gen.ts | 2 + tests/unit/test_registry.py | 12 +-- tracecat/api/common.py | 2 +- tracecat/dsl/scheduler.py | 4 +- tracecat/git.py | 89 +++++++++++++++-- tracecat/identifiers/__init__.py | 1 + tracecat/registry/repository.py | 2 - tracecat/settings/service.py | 2 + tracecat/ssh.py | 148 ++++++++++++++++++++++++----- tracecat/types/exceptions.py | 4 + 11 files changed, 222 insertions(+), 45 deletions(-) diff --git a/frontend/src/client/schemas.gen.ts b/frontend/src/client/schemas.gen.ts index a90201d2f..408afbb1b 100644 --- a/frontend/src/client/schemas.gen.ts +++ b/frontend/src/client/schemas.gen.ts @@ -2353,6 +2353,7 @@ export const $Role = { "tracecat-schedule-runner", "tracecat-service", "tracecat-executor", + "tracecat-bootstrap", ], title: "Service Id", }, diff --git a/frontend/src/client/types.gen.ts b/frontend/src/client/types.gen.ts index 621b91943..9f4f0e949 100644 --- a/frontend/src/client/types.gen.ts +++ b/frontend/src/client/types.gen.ts @@ -781,6 +781,7 @@ export type Role = { | "tracecat-schedule-runner" | "tracecat-service" | "tracecat-executor" + | "tracecat-bootstrap" } export type type2 = "user" | "service" @@ -792,6 +793,7 @@ export type service_id = | "tracecat-schedule-runner" | "tracecat-service" | "tracecat-executor" + | "tracecat-bootstrap" /** * This object contains all the information needed to execute an action. 
diff --git a/tests/unit/test_registry.py b/tests/unit/test_registry.py index 7554f7b08..c41ec3298 100644 --- a/tests/unit/test_registry.py +++ b/tests/unit/test_registry.py @@ -154,17 +154,17 @@ async def test_registry_async_function_can_be_called(mock_package): host="github.com", org="org", repo="repo", - branch="main", + ref=None, ), ), - # GitHub (with branch) + # GitHub (with branch/sha) ( - "git+ssh://git@github.com/org/repo@branch", + "git+ssh://git@github.com/org/repo@branchOrSHAOrTag", GitUrl( host="github.com", org="org", repo="repo", - branch="branch", + ref="branchOrSHAOrTag", ), ), # GitLab @@ -174,7 +174,7 @@ async def test_registry_async_function_can_be_called(mock_package): host="gitlab.com", org="org", repo="repo", - branch="main", + ref=None, ), ), # GitLab (with branch) @@ -184,7 +184,7 @@ async def test_registry_async_function_can_be_called(mock_package): host="gitlab.com", org="org", repo="repo", - branch="branch", + ref="branch", ), ), ], diff --git a/tracecat/api/common.py b/tracecat/api/common.py index 9a00f0a87..8f6c7689f 100644 --- a/tracecat/api/common.py +++ b/tracecat/api/common.py @@ -28,7 +28,7 @@ def bootstrap_role(): return Role( type="service", access_level=AccessLevel.ADMIN, - service_id="tracecat-api", + service_id="tracecat-bootstrap", ) diff --git a/tracecat/dsl/scheduler.py b/tracecat/dsl/scheduler.py index 75c305d9f..9bd275f0f 100644 --- a/tracecat/dsl/scheduler.py +++ b/tracecat/dsl/scheduler.py @@ -148,7 +148,7 @@ async def _queue_tasks( ) async with asyncio.TaskGroup() as tg: for next_ref, edge_type in next_tasks: - self.logger.warning("Processing next task", ref=ref, next_ref=next_ref) + self.logger.debug("Processing next task", ref=ref, next_ref=next_ref) edge = DSLEdge(src=ref, dst=next_ref, type=edge_type) if unreachable and edge in unreachable: self._mark_edge(edge, EdgeMarker.SKIPPED) @@ -272,7 +272,7 @@ def _is_reachable(self, task: ActionStatement) -> bool: # Root nodes are always reachable return True elif n_deps 
== 1: - logger.warning("Task has only 1 dependency", task=task) + logger.debug("Task has only 1 dependency", task=task) # If there's only 1 dependency, the node is reachable only if the # dependency was successful ignoring the join strategy. dep_ref = task.depends_on[0] diff --git a/tracecat/git.py b/tracecat/git.py index 2d57a88d3..98f364471 100644 --- a/tracecat/git.py +++ b/tracecat/git.py @@ -1,13 +1,21 @@ import asyncio import re from dataclasses import dataclass +from typing import cast +from tracecat import config +from tracecat.contexts import ctx_role from tracecat.logger import logger +from tracecat.registry.repositories.service import RegistryReposService +from tracecat.settings.service import get_setting from tracecat.ssh import SshEnv +from tracecat.types.auth import Role +from tracecat.types.exceptions import TracecatSettingsError GIT_SSH_URL_REGEX = re.compile( - r"^git\+ssh://git@(?P<host>[^/]+)/(?P<org>[^/]+)/(?P<repo>[^/@]+?)(?:\.git)?(?:@(?P<branch>[^/]+))?$" + r"^git\+ssh://git@(?P<host>[^/]+)/(?P<org>[^/]+)/(?P<repo>[^/@]+?)(?:\.git)?(?:@(?P<ref>[^/]+))?$" ) +"""Git SSH URL with git user and optional ref.""" @dataclass @@ -15,7 +23,11 @@ class GitUrl: host: str org: str repo: str - branch: str + ref: str | None = None + + def to_url(self) -> str: + base = f"git+ssh://git@{self.host}/{self.org}/{self.repo}.git" + return f"{base}@{self.ref}" if self.ref else base async def get_git_repository_sha(repo_url: str, env: SshEnv) -> str: @@ -38,8 +50,8 @@ async def get_git_repository_sha(repo_url: str, env: SshEnv) -> str: raise RuntimeError(f"Failed to get repository SHA: {error_message}") # The output format is: "<SHA>\tHEAD" - sha = stdout.decode().split()[0] - return sha + ref = stdout.decode().split()[0] + return ref except Exception as e: logger.error("Error getting repository SHA", error=str(e)) @@ -63,16 +75,73 @@ def parse_git_url(url: str, *, allowed_domains: set[str] | None = None) -> GitUr if match := GIT_SSH_URL_REGEX.match(url): host = match.group("host") + org = match.group("org") + repo = 
match.group("repo") + ref = match.group("ref") + + if ( + not isinstance(host, str) + or not isinstance(org, str) + or not isinstance(repo, str) + ): + raise ValueError(f"Invalid Git URL: {url}") + if allowed_domains and host not in allowed_domains: raise ValueError( f"Domain {host} not in allowed domains. Must be configured in `git_allowed_domains` organization setting." ) - return GitUrl( - host=host, - org=match.group("org"), - repo=match.group("repo"), - branch=match.group("branch") or "main", - ) + return GitUrl(host=host, org=org, repo=repo, ref=ref) raise ValueError(f"Unsupported URL format: {url}. Must be a valid Git SSH URL.") + + +async def prepare_git_url(role: Role | None = None) -> GitUrl | None: + """Construct the runtime environment + Deps: + In the new pull-model registry, the execution environment is ALL the registries + 1. Tracecat registry + 2. User's custom template registry + 3. User's custom UDF registry (github) + + Why? + Since we no longer depend on the user to push to executor, the db repos are now + the source of truth. 
+ """ + role = role or ctx_role.get() + + # Handle the git repo + url = await get_setting( + "git_repo_url", + # TODO: Deprecate in future version + default=config.TRACECAT__REMOTE_REPOSITORY_URL, + ) + if not url or not isinstance(url, str): + logger.debug("No git URL found") + return None + + logger.debug("Runtime environment", url=url) + + allowed_domains_setting = await get_setting( + "git_allowed_domains", + # TODO: Deprecate in future version + default=config.TRACECAT__ALLOWED_GIT_DOMAINS, + ) + allowed_domains = cast(set[str], allowed_domains_setting or {"github.com"}) + + # Grab the sha + # Find the repository that has the same origin + sha = None + async with RegistryReposService.with_session(role=role) as service: + repo = await service.get_repository(origin=url) + sha = repo.commit_sha if repo else None + + try: + # Validate + git_url = parse_git_url(url, allowed_domains=allowed_domains) + except ValueError as e: + raise TracecatSettingsError( + "Invalid Git repository URL. Please provide a valid Git SSH URL (git+ssh)." + ) from e + git_url.ref = sha + return git_url diff --git a/tracecat/identifiers/__init__.py b/tracecat/identifiers/__init__.py index bb88af4d4..6e1d772f3 100644 --- a/tracecat/identifiers/__init__.py +++ b/tracecat/identifiers/__init__.py @@ -73,6 +73,7 @@ "tracecat-schedule-runner", "tracecat-service", "tracecat-executor", + "tracecat-bootstrap", ] __all__ = [ diff --git a/tracecat/registry/repository.py b/tracecat/registry/repository.py index 2069d7875..830c60e3a 100644 --- a/tracecat/registry/repository.py +++ b/tracecat/registry/repository.py @@ -263,7 +263,6 @@ async def load_from_origin(self, commit_sha: str | None = None) -> str | None: host = git_url.host org = git_url.org repo_name = git_url.repo - branch = git_url.branch except ValueError as e: raise RegistryError( "Invalid Git repository URL. Please provide a valid Git SSH URL (git+ssh)." 
@@ -283,7 +282,6 @@ async def load_from_origin(self, commit_sha: str | None = None) -> str | None: org=org, repo=repo_name, package_name=package_name, - ref=branch, ) cleaned_url = self.safe_remote_url(self._origin) diff --git a/tracecat/settings/service.py b/tracecat/settings/service.py index d23cd3796..1968e3626 100644 --- a/tracecat/settings/service.py +++ b/tracecat/settings/service.py @@ -10,6 +10,7 @@ from tracecat import config from tracecat.authz.controls import require_access_level +from tracecat.contexts import ctx_role from tracecat.db.schemas import OrganizationSetting from tracecat.logger import logger from tracecat.secrets.encryption import decrypt_value, encrypt_value @@ -264,6 +265,7 @@ async def get_setting( default: Any | None = None, ) -> Any | None: """Shorthand to get a setting value from the database.""" + role = role or ctx_role.get() # If we have an environment override, use it if override_val := get_setting_override(key): diff --git a/tracecat/ssh.py b/tracecat/ssh.py index 95c3123f4..8f5979e0e 100644 --- a/tracecat/ssh.py +++ b/tracecat/ssh.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import asyncio import os import subprocess @@ -6,10 +8,19 @@ from contextlib import asynccontextmanager from dataclasses import dataclass from pathlib import Path +from typing import TYPE_CHECKING +import aiofiles import paramiko +from sqlmodel.ext.asyncio.session import AsyncSession +from tracecat.contexts import ctx_role from tracecat.logger import logger +from tracecat.secrets.service import SecretsService +from tracecat.types.auth import Role + +if TYPE_CHECKING: + from tracecat.git import GitUrl @dataclass @@ -80,8 +91,16 @@ async def temporary_ssh_agent() -> AsyncIterator[SshEnv]: logger.debug("Killed ssh-agent") -async def add_host_to_known_hosts(url: str, *, env: SshEnv) -> None: - """Add the host to the known hosts file.""" +def add_host_to_known_hosts_sync(url: str, env: SshEnv) -> None: + """Synchronously add the host to the known hosts 
file if not already present. + + Args: + url: The host URL to add + env: SSH environment variables + + Raises: + Exception: If ssh-keyscan fails to get the host key + """ try: # Ensure the ~/.ssh directory exists ssh_dir = Path.home() / ".ssh" @@ -89,33 +108,43 @@ async def add_host_to_known_hosts(url: str, *, env: SshEnv) -> None: known_hosts_file = ssh_dir / "known_hosts" + # Check if host already exists in known_hosts + if known_hosts_file.exists(): + with known_hosts_file.open("r") as f: + # Look for the hostname in existing entries + if any(url in line for line in f.readlines()): + logger.debug("Host already in known_hosts file", url=url) + return # Use ssh-keyscan to get the host key - process = await asyncio.create_subprocess_exec( - "ssh-keyscan", - url, - stdout=asyncio.subprocess.PIPE, - stderr=asyncio.subprocess.PIPE, + result = subprocess.run( + ["ssh-keyscan", url], + capture_output=True, + text=True, env=env.to_dict(), + check=False, ) - stdout, stderr = await process.communicate() - if process.returncode != 0: - raise Exception(f"Failed to get host key: {stderr.decode().strip()}") + if result.returncode != 0: + raise RuntimeError(f"Failed to get host key: {result.stderr.strip()}") # Append the host key to the known_hosts file with known_hosts_file.open("a") as f: - f.write(stdout.decode()) + f.write(result.stdout) logger.info("Added host to known hosts", url=url) except Exception as e: - logger.error(f"Error adding host to known hosts: {str(e)}") + logger.error("Error adding host to known hosts", error=e) raise -async def add_ssh_key_to_agent(key_data: str, env: SshEnv) -> None: - """Add the SSH key to the agent then remove it.""" - # TODO(perf): Improve concurrency - with tempfile.NamedTemporaryFile(mode="w", delete=False) as temp_key_file: +async def add_host_to_known_hosts(url: str, *, env: SshEnv) -> None: + """Asynchronously add the host to the known hosts file.""" + return await asyncio.to_thread(add_host_to_known_hosts_sync, url, env) + + +def 
add_ssh_key_to_agent_sync(key_data: str, env: SshEnv) -> None: + """Synchronously add the SSH key to the agent then remove it.""" + with tempfile.NamedTemporaryFile(mode="w", delete=True) as temp_key_file: temp_key_file.write(key_data) temp_key_file.write("\n") temp_key_file.flush() @@ -130,19 +159,90 @@ async def add_ssh_key_to_agent(key_data: str, env: SshEnv) -> None: raise try: - process = await asyncio.create_subprocess_exec( - "ssh-add", - temp_key_file.name, - stdout=asyncio.subprocess.PIPE, - stderr=asyncio.subprocess.PIPE, + result = subprocess.run( + ["ssh-add", temp_key_file.name], + capture_output=True, + text=True, env=env.to_dict(), + check=False, ) - _, stderr = await process.communicate() - if process.returncode != 0: - raise Exception(f"Failed to add SSH key: {stderr.decode().strip()}") + if result.returncode != 0: + raise RuntimeError(f"Failed to add SSH key: {result.stderr.strip()}") logger.info("Added SSH key to agent") except Exception as e: logger.error("Error adding SSH key", error=e) raise + + +async def add_ssh_key_to_agent(key_data: str, env: SshEnv) -> None: + """Asynchronously add the SSH key to the agent then remove it.""" + return await asyncio.to_thread(add_ssh_key_to_agent_sync, key_data, env) + + +@asynccontextmanager +async def temp_key_file(key_content: str) -> AsyncIterator[str]: + """Create a temporary file containing an SSH key with secure permissions. + + Args: + key_content: The SSH key content to write to the temporary file + + Returns: + An SSH command string configured to use the temporary key file + + Raises: + OSError: If unable to create temp file or set permissions + """ + async with aiofiles.tempfile.NamedTemporaryFile(mode="w", delete=True) as f: + # Write key content + await f.write(key_content) + await f.flush() + + # Set strict permissions (important!) 
+ os.chmod(f.name, 0o600) + + # Use the key file in SSH command with more permissive host key checking + ssh_cmd = ( + f"ssh -i {f.name} -o IdentitiesOnly=yes " + "-o StrictHostKeyChecking=accept-new " + f"-o UserKnownHostsFile={Path.home().joinpath('.ssh/known_hosts')!s}" + ) + yield ssh_cmd + + +@asynccontextmanager +async def opt_temp_key_file( + git_url: GitUrl | None, + session: AsyncSession, + role: Role | None = None, +) -> AsyncIterator[str | None]: + """Context manager for optional SSH key file.""" + if git_url is None: + yield None + else: + role = role or ctx_role.get() + service = SecretsService(session=session, role=role) + ssh_key = await service.get_ssh_key() + async with temp_key_file(key_content=ssh_key.reveal().value) as ssh_cmd: + yield ssh_cmd + + +@asynccontextmanager +async def ssh_context( + *, + git_url: GitUrl | None = None, + session: AsyncSession, + role: Role | None = None, +) -> AsyncIterator[SshEnv | None]: + """Context manager for SSH environment variables.""" + if git_url is None: + yield None + else: + logger.info("Getting SSH key", role=role, git_url=git_url) + sec_svc = SecretsService(session, role=role) + secret = await sec_svc.get_ssh_key() + async with temporary_ssh_agent() as env: + await add_ssh_key_to_agent(secret.reveal().value, env=env) + await add_host_to_known_hosts(git_url.host, env=env) + yield env diff --git a/tracecat/types/exceptions.py b/tracecat/types/exceptions.py index 20edb3691..af519efbe 100644 --- a/tracecat/types/exceptions.py +++ b/tracecat/types/exceptions.py @@ -92,3 +92,7 @@ class WrappedExecutionError(TracecatException): def __init__(self, error: Any): self.error = error + + +class TracecatSettingsError(TracecatException): + """Exception raised when a setting error occurs."""