Skip to content

Commit

Permalink
feat(engine): Add and update git and ssh modules (#754)
Browse files Browse the repository at this point in the history
  • Loading branch information
daryllimyt authored Jan 15, 2025
1 parent a12b87c commit cc9c7c2
Show file tree
Hide file tree
Showing 11 changed files with 222 additions and 45 deletions.
1 change: 1 addition & 0 deletions frontend/src/client/schemas.gen.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2353,6 +2353,7 @@ export const $Role = {
"tracecat-schedule-runner",
"tracecat-service",
"tracecat-executor",
"tracecat-bootstrap",
],
title: "Service Id",
},
Expand Down
2 changes: 2 additions & 0 deletions frontend/src/client/types.gen.ts
Original file line number Diff line number Diff line change
Expand Up @@ -781,6 +781,7 @@ export type Role = {
| "tracecat-schedule-runner"
| "tracecat-service"
| "tracecat-executor"
| "tracecat-bootstrap"
}

export type type2 = "user" | "service"
Expand All @@ -792,6 +793,7 @@ export type service_id =
| "tracecat-schedule-runner"
| "tracecat-service"
| "tracecat-executor"
| "tracecat-bootstrap"

/**
* This object contains all the information needed to execute an action.
Expand Down
12 changes: 6 additions & 6 deletions tests/unit/test_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,17 +154,17 @@ async def test_registry_async_function_can_be_called(mock_package):
host="github.com",
org="org",
repo="repo",
branch="main",
ref=None,
),
),
# GitHub (with branch)
# GitHub (with branch/sha)
(
"git+ssh://[email protected]/org/repo@branch",
"git+ssh://[email protected]/org/repo@branchOrSHAOrTag",
GitUrl(
host="github.com",
org="org",
repo="repo",
branch="branch",
ref="branchOrSHAOrTag",
),
),
# GitLab
Expand All @@ -174,7 +174,7 @@ async def test_registry_async_function_can_be_called(mock_package):
host="gitlab.com",
org="org",
repo="repo",
branch="main",
ref=None,
),
),
# GitLab (with branch)
Expand All @@ -184,7 +184,7 @@ async def test_registry_async_function_can_be_called(mock_package):
host="gitlab.com",
org="org",
repo="repo",
branch="branch",
ref="branch",
),
),
],
Expand Down
2 changes: 1 addition & 1 deletion tracecat/api/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ def bootstrap_role():
return Role(
type="service",
access_level=AccessLevel.ADMIN,
service_id="tracecat-api",
service_id="tracecat-bootstrap",
)


Expand Down
4 changes: 2 additions & 2 deletions tracecat/dsl/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,7 +148,7 @@ async def _queue_tasks(
)
async with asyncio.TaskGroup() as tg:
for next_ref, edge_type in next_tasks:
self.logger.warning("Processing next task", ref=ref, next_ref=next_ref)
self.logger.debug("Processing next task", ref=ref, next_ref=next_ref)
edge = DSLEdge(src=ref, dst=next_ref, type=edge_type)
if unreachable and edge in unreachable:
self._mark_edge(edge, EdgeMarker.SKIPPED)
Expand Down Expand Up @@ -272,7 +272,7 @@ def _is_reachable(self, task: ActionStatement) -> bool:
# Root nodes are always reachable
return True
elif n_deps == 1:
logger.warning("Task has only 1 dependency", task=task)
logger.debug("Task has only 1 dependency", task=task)
# If there's only 1 dependency, the node is reachable only if the
# dependency was successful ignoring the join strategy.
dep_ref = task.depends_on[0]
Expand Down
89 changes: 79 additions & 10 deletions tracecat/git.py
Original file line number Diff line number Diff line change
@@ -1,21 +1,33 @@
import asyncio
import re
from dataclasses import dataclass
from typing import cast

from tracecat import config
from tracecat.contexts import ctx_role
from tracecat.logger import logger
from tracecat.registry.repositories.service import RegistryReposService
from tracecat.settings.service import get_setting
from tracecat.ssh import SshEnv
from tracecat.types.auth import Role
from tracecat.types.exceptions import TracecatSettingsError

GIT_SSH_URL_REGEX = re.compile(
r"^git\+ssh://git@(?P<host>[^/]+)/(?P<org>[^/]+)/(?P<repo>[^/@]+?)(?:\.git)?(?:@(?P<branch>[^/]+))?$"
r"^git\+ssh://git@(?P<host>[^/]+)/(?P<org>[^/]+)/(?P<repo>[^/@]+?)(?:\.git)?(?:@(?P<ref>[^/]+))?$"
)
"""Git SSH URL with git user and optional ref."""


@dataclass
class GitUrl:
host: str
org: str
repo: str
branch: str
ref: str | None = None

def to_url(self) -> str:
base = f"git+ssh://git@{self.host}/{self.org}/{self.repo}.git"
return f"{base}@{self.ref}" if self.ref else base


async def get_git_repository_sha(repo_url: str, env: SshEnv) -> str:
Expand All @@ -38,8 +50,8 @@ async def get_git_repository_sha(repo_url: str, env: SshEnv) -> str:
raise RuntimeError(f"Failed to get repository SHA: {error_message}")

# The output format is: "<SHA>\tHEAD"
sha = stdout.decode().split()[0]
return sha
ref = stdout.decode().split()[0]
return ref

except Exception as e:
logger.error("Error getting repository SHA", error=str(e))
Expand All @@ -63,16 +75,73 @@ def parse_git_url(url: str, *, allowed_domains: set[str] | None = None) -> GitUr

if match := GIT_SSH_URL_REGEX.match(url):
host = match.group("host")
org = match.group("org")
repo = match.group("repo")
ref = match.group("ref")

if (
not isinstance(host, str)
or not isinstance(org, str)
or not isinstance(repo, str)
):
raise ValueError(f"Invalid Git URL: {url}")

if allowed_domains and host not in allowed_domains:
raise ValueError(
f"Domain {host} not in allowed domains. Must be configured in `git_allowed_domains` organization setting."
)

return GitUrl(
host=host,
org=match.group("org"),
repo=match.group("repo"),
branch=match.group("branch") or "main",
)
return GitUrl(host=host, org=org, repo=repo, ref=ref)

raise ValueError(f"Unsupported URL format: {url}. Must be a valid Git SSH URL.")


async def prepare_git_url(role: Role | None = None) -> GitUrl | None:
"""Construct the runtime environment
Deps:
In the new pull-model registry, the execution environment is ALL the registries
1. Tracecat registry
2. User's custom template registry
3. User's custom UDF registry (github)
Why?
Since we no longer depend on the user to push to executor, the db repos are now
the source of truth.
"""
role = role or ctx_role.get()

# Handle the git repo
url = await get_setting(
"git_repo_url",
# TODO: Deprecate in future version
default=config.TRACECAT__REMOTE_REPOSITORY_URL,
)
if not url or not isinstance(url, str):
logger.debug("No git URL found")
return None

logger.debug("Runtime environment", url=url)

allowed_domains_setting = await get_setting(
"git_allowed_domains",
# TODO: Deprecate in future version
default=config.TRACECAT__ALLOWED_GIT_DOMAINS,
)
allowed_domains = cast(set[str], allowed_domains_setting or {"github.com"})

# Grab the sha
# Find the repository that has the same origin
sha = None
async with RegistryReposService.with_session(role=role) as service:
repo = await service.get_repository(origin=url)
sha = repo.commit_sha if repo else None

try:
# Validate
git_url = parse_git_url(url, allowed_domains=allowed_domains)
except ValueError as e:
raise TracecatSettingsError(
"Invalid Git repository URL. Please provide a valid Git SSH URL (git+ssh)."
) from e
git_url.ref = sha
return git_url
1 change: 1 addition & 0 deletions tracecat/identifiers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@
"tracecat-schedule-runner",
"tracecat-service",
"tracecat-executor",
"tracecat-bootstrap",
]

__all__ = [
Expand Down
2 changes: 0 additions & 2 deletions tracecat/registry/repository.py
Original file line number Diff line number Diff line change
Expand Up @@ -263,7 +263,6 @@ async def load_from_origin(self, commit_sha: str | None = None) -> str | None:
host = git_url.host
org = git_url.org
repo_name = git_url.repo
branch = git_url.branch
except ValueError as e:
raise RegistryError(
"Invalid Git repository URL. Please provide a valid Git SSH URL (git+ssh)."
Expand All @@ -283,7 +282,6 @@ async def load_from_origin(self, commit_sha: str | None = None) -> str | None:
org=org,
repo=repo_name,
package_name=package_name,
ref=branch,
)

cleaned_url = self.safe_remote_url(self._origin)
Expand Down
2 changes: 2 additions & 0 deletions tracecat/settings/service.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@

from tracecat import config
from tracecat.authz.controls import require_access_level
from tracecat.contexts import ctx_role
from tracecat.db.schemas import OrganizationSetting
from tracecat.logger import logger
from tracecat.secrets.encryption import decrypt_value, encrypt_value
Expand Down Expand Up @@ -264,6 +265,7 @@ async def get_setting(
default: Any | None = None,
) -> Any | None:
"""Shorthand to get a setting value from the database."""
role = role or ctx_role.get()

# If we have an environment override, use it
if override_val := get_setting_override(key):
Expand Down
Loading

0 comments on commit cc9c7c2

Please sign in to comment.