diff --git a/packages/grid/helm/syft/templates/backend-service-account.yaml b/packages/grid/helm/syft/templates/backend-service-account.yaml index 35be3230bd5..97608b4fd4e 100644 --- a/packages/grid/helm/syft/templates/backend-service-account.yaml +++ b/packages/grid/helm/syft/templates/backend-service-account.yaml @@ -36,7 +36,7 @@ metadata: app.kubernetes.io/managed-by: Helm rules: - apiGroups: [""] - resources: ["pods", "configmaps"] + resources: ["pods", "configmaps", "secrets"] verbs: ["create", "get", "list", "watch", "update", "patch", "delete"] - apiGroups: [""] resources: ["pods/log"] diff --git a/packages/syft/src/syft/custom_worker/runner_k8s.py b/packages/syft/src/syft/custom_worker/runner_k8s.py index 3f0824182fb..0838836989b 100644 --- a/packages/syft/src/syft/custom_worker/runner_k8s.py +++ b/packages/syft/src/syft/custom_worker/runner_k8s.py @@ -1,5 +1,7 @@ # stdlib +import base64 import copy +import json import os from time import sleep from typing import List @@ -10,6 +12,7 @@ import kr8s from kr8s.objects import APIObject from kr8s.objects import Pod +from kr8s.objects import Secret from kr8s.objects import StatefulSet # relative @@ -27,16 +30,35 @@ def create_pool( tag: str, replicas: int = 1, env_vars: Optional[dict] = None, + reg_username: Optional[str] = None, + reg_password: Optional[str] = None, + reg_url: Optional[str] = None, **kwargs, ) -> StatefulSet: + # create pull secret if registry credentials are passed + pull_secret = None + if reg_username and reg_password and reg_url: + pull_secret = self._create_image_pull_secret( + pool_name, + reg_username, + reg_password, + reg_url, + ) + + # create a stateful set deployment deployment = self._create_stateful_set( pool_name, tag, replicas, env_vars, + pull_secret=pull_secret, **kwargs, ) + + # wait for replicas to be available and ready self.wait(deployment, available_replicas=replicas) + + # return return deployment def scale_pool(self, pool_name: str, replicas: int) -> Optional[StatefulSet]: @@ -57,8 +79,11 @@ def delete_pool(self, pool_name: str) -> bool: selector = {"app.kubernetes.io/component": pool_name} for _set in self.client.get("statefulsets", label_selector=selector): _set.delete() - return True - return False + + for _secret in self.client.get("secrets", label_selector=selector): + _secret.delete() + + return True def delete_pod(self, pod_name: str) -> bool: pods = self.client.get("pods", pod_name) @@ -99,7 +124,7 @@ def wait( self, deployment: StatefulSet, available_replicas: int, - timeout: int = 60, + timeout: int = 300, ) -> None: # TODO: Report wait('jsonpath=') bug to kr8s # Until then this is the substitute implementation @@ -133,17 +158,50 @@ def _get_obj_from_list(self, objs: List[dict], name: str) -> dict: if obj.name == name: return obj + def _create_image_pull_secret( + self, + pool_name: str, + reg_username: str, + reg_password: str, + reg_url: str, + **kwargs, + ): + _secret = Secret( + { + "metadata": { + "name": f"pull-secret-{pool_name}", + "labels": { + "app.kubernetes.io/name": KUBERNETES_NAMESPACE, + "app.kubernetes.io/component": pool_name, + "app.kubernetes.io/managed-by": "kr8s", + }, + }, + "type": "kubernetes.io/dockerconfigjson", + "data": { + ".dockerconfigjson": self._create_dockerconfig_json( + reg_username, + reg_password, + reg_url, + ) + }, + } + ) + + return self._create_or_get(_secret) + def _create_stateful_set( self, pool_name: str, tag: str, replicas=1, env_vars: Optional[dict] = None, + pull_secret: Optional[Secret] = None, **kwargs, ) -> StatefulSet: """Create a stateful set for a pool""" env_vars = env_vars or {} + pull_secret_obj = None _pod = Pod.get(self._current_pod_name()) @@ -170,6 +228,13 @@ def _create_stateful_set( for k, v in env_vars.items(): env_clone.append({"name": k, "value": v}) + if pull_secret: + pull_secret_obj = [ + { + "name": pull_secret.name, + } + ] + stateful_set = StatefulSet( { "metadata": { @@ -198,12 +263,14 @@ def _create_stateful_set( "containers": [ { "name": pool_name, + "imagePullPolicy": "IfNotPresent", "image": tag, "env": env_clone, "volumeMounts": [creds_volume_mount], } ], "volumes": [creds_volume], + "imagePullSecrets": pull_secret_obj, }, }, }, @@ -217,3 +284,22 @@ def _create_or_get(self, obj: APIObject) -> APIObject: else: obj.refresh() return obj + + def _create_dockerconfig_json( + self, + reg_username: str, + reg_password: str, + reg_url: str, + ): + config = { + "auths": { + reg_url: { + "username": reg_username, + "password": reg_password, + "auth": base64.b64encode( + f"{reg_username}:{reg_password}".encode() + ).decode(), + } + } + } + return base64.b64encode(json.dumps(config).encode()).decode() diff --git a/packages/syft/src/syft/service/request/request.py b/packages/syft/src/syft/service/request/request.py index b606512780a..699e577e2b9 100644 --- a/packages/syft/src/syft/service/request/request.py +++ b/packages/syft/src/syft/service/request/request.py @@ -235,6 +235,8 @@ def _run( push_result = worker_image_service.push( service_context, image=worker_image.id, + username=context.extra_kwargs.get("reg_username", None), + password=context.extra_kwargs.get("reg_password", None), ) if isinstance(push_result, SyftError): @@ -299,6 +301,8 @@ def _run( name=self.pool_name, image_uid=self.image_uid, num_workers=self.num_workers, + reg_username=context.extra_kwargs.get("reg_username", None), + reg_password=context.extra_kwargs.get("reg_password", None), ) if isinstance(result, SyftError): return Err(result) @@ -487,7 +491,12 @@ def status(self) -> RequestStatus: return request_status - def approve(self, disable_warnings: bool = False, approve_nested: bool = False): + def approve( + self, + disable_warnings: bool = False, + approve_nested: bool = False, + **kwargs: dict, + ): api = APIRegistry.api_for( self.node_uid, self.syft_client_verify_key, @@ -518,7 +527,7 @@ def approve(self, disable_warnings: bool = False, approve_nested: bool = False): prompt_warning_message(message=message, confirm=True) print(f"Approving request for domain {api.node_name}") - return api.services.request.apply(self.id) + return api.services.request.apply(self.id, **kwargs) def deny(self, reason: str): """Denies the particular request. diff --git a/packages/syft/src/syft/service/request/request_service.py b/packages/syft/src/syft/service/request/request_service.py index c0bb7aea4cf..b9bfe8761e6 100644 --- a/packages/syft/src/syft/service/request/request_service.py +++ b/packages/syft/src/syft/service/request/request_service.py @@ -233,11 +233,16 @@ def filter_all_info( name="apply", ) def apply( - self, context: AuthedServiceContext, uid: UID + self, + context: AuthedServiceContext, + uid: UID, + **kwargs: dict, ) -> Union[SyftSuccess, SyftError]: request = self.stash.get_by_uid(context.credentials, uid) if request.is_ok(): request = request.ok() + + context.extra_kwargs = kwargs result = request.apply(context=context) filter_by_obj = context.node.get_service_method( diff --git a/packages/syft/src/syft/service/worker/image_registry.py b/packages/syft/src/syft/service/worker/image_registry.py index 091b876116c..cf3c36c0e0d 100644 --- a/packages/syft/src/syft/service/worker/image_registry.py +++ b/packages/syft/src/syft/service/worker/image_registry.py @@ -1,3 +1,9 @@ +# stdlib +from urllib.parse import urlparse + +# third party +from pydantic import validator + # relative from ...serde.serializable import serializable from ...types.syft_object import SYFT_OBJECT_VERSION_1 @@ -18,13 +24,21 @@ class SyftImageRegistry(SyftObject): id: UID url: str + @validator("url") + def validate_url(cls, val: str): + if val.startswith("http") or "://" in val: + raise ValueError("Registry URL must be a valid RFC 3986 URI") + return val + @classmethod def from_url(cls, full_str: str): - return cls(id=UID(), url=full_str) + if "://" not in full_str: + full_str = f"http://{full_str}" + + parsed = urlparse(full_str) - @property - def tls_enabled(self) -> bool: - return self.url.startswith("https") + # netloc includes the host & port, so local dev should work as expected + return cls(id=UID(), url=parsed.netloc) def __hash__(self) -> int: return hash(self.url + str(self.tls_enabled)) diff --git a/packages/syft/src/syft/service/worker/image_registry_service.py b/packages/syft/src/syft/service/worker/image_registry_service.py index 5d3f30121ed..8c628f84486 100644 --- a/packages/syft/src/syft/service/worker/image_registry_service.py +++ b/packages/syft/src/syft/service/worker/image_registry_service.py @@ -39,11 +39,15 @@ def add( context: AuthedServiceContext, url: str, ) -> Union[SyftSuccess, SyftError]: - registry = SyftImageRegistry.from_url(url) + try: + registry = SyftImageRegistry.from_url(url) + except Exception as e: + return SyftError(message=f"Failed to create registry. {e}") + res = self.stash.set(context.credentials, registry) if res.is_err(): - return SyftError(message=res.err()) + return SyftError(message=f"Failed to create registry. {res.err()}") return SyftSuccess( message=f"Image Registry ID: {registry.id} created successfully" diff --git a/packages/syft/src/syft/service/worker/utils.py b/packages/syft/src/syft/service/worker/utils.py index 5cf42a12883..9f677eb7f98 100644 --- a/packages/syft/src/syft/service/worker/utils.py +++ b/packages/syft/src/syft/service/worker/utils.py @@ -261,6 +261,10 @@ def create_kubernetes_pool( replicas: int, queue_port: int, debug: bool, + reg_username: Optional[str] = None, + reg_password: Optional[str] = None, + reg_url: Optional[str] = None, + **kwargs, ): pool = None error = False @@ -285,6 +289,9 @@ def create_kubernetes_pool( "CREATE_PRODUCER": "False", "INMEMORY_WORKERS": "False", }, + reg_username=reg_username, + reg_password=reg_password, + reg_url=reg_url, ) except Exception as e: error = True @@ -321,9 +328,9 @@ def run_workers_in_kubernetes( queue_port: int, start_idx=0, debug: bool = False, - username: Optional[str] = None, - password: Optional[str] = None, - registry_url: Optional[str] = None, + reg_username: Optional[str] = None, + reg_password: Optional[str] = None, + reg_url: Optional[str] = None, **kwargs, ) -> Union[List[ContainerSpawnStatus], SyftError]: spawn_status = [] @@ -331,12 +338,15 @@ def run_workers_in_kubernetes( if start_idx == 0: pool_pods = create_kubernetes_pool( - runner, - worker_image, - pool_name, - worker_count, - queue_port, - debug, + runner=runner, + worker_image=worker_image, + pool_name=pool_name, + replicas=worker_count, + queue_port=queue_port, + debug=debug, + reg_username=reg_username, + reg_password=reg_password, + reg_url=reg_url, ) else: pool_pods = scale_kubernetes_pool(runner, pool_name, worker_count) @@ -412,9 +422,9 @@ def run_containers( queue_port: int, dev_mode: bool = False, start_idx: int = 0, - username: Optional[str] = None, - password: Optional[str] = None, - registry_url: Optional[str] = None, + reg_username: Optional[str] = None, + reg_password: Optional[str] = None, + reg_url: Optional[str] = None, ) -> Union[List[ContainerSpawnStatus], SyftError]: results = [] @@ -435,9 +445,9 @@ def run_containers( pool_name=pool_name, queue_port=queue_port, debug=dev_mode, - username=username, - password=password, - registry_url=registry_url, + username=reg_username, + password=reg_password, + registry_url=reg_url, ) results.append(spawn_result) elif orchestration == WorkerOrchestrationType.KUBERNETES: @@ -448,6 +458,9 @@ def run_containers( queue_port=queue_port, debug=dev_mode, start_idx=start_idx, + reg_username=reg_username, + reg_password=reg_password, + reg_url=reg_url, ) return results diff --git a/packages/syft/src/syft/service/worker/worker_pool_service.py b/packages/syft/src/syft/service/worker/worker_pool_service.py index 3787ab62e95..806881547da 100644 --- a/packages/syft/src/syft/service/worker/worker_pool_service.py +++ b/packages/syft/src/syft/service/worker/worker_pool_service.py @@ -349,6 +349,8 @@ def add_workers( number: int, pool_id: Optional[UID] = None, pool_name: Optional[str] = None, + reg_username: Optional[str] = None, + reg_password: Optional[str] = None, ) -> Union[List[ContainerSpawnStatus], SyftError]: """Add workers to existing worker pool. @@ -406,6 +408,8 @@ def add_workers( worker_cnt=number, worker_image=worker_image, worker_stash=worker_stash, + reg_username=reg_username, + reg_password=reg_password, ) if isinstance(result, SyftError): @@ -570,9 +574,9 @@ def _create_workers_in_pool( orchestration=get_orchestration_type(), queue_port=queue_port, dev_mode=context.node.dev_mode, - username=reg_username, - password=reg_password, - registry_url=worker_image.image_identifier.registry_host, + reg_username=reg_username, + reg_password=reg_password, + reg_url=worker_image.image_identifier.registry_host, ) if isinstance(result, SyftError): return result