Skip to content

Commit

Permalink
Merge pull request #8444 from OpenMined/yash/k8s-pull-secrets
Browse files Browse the repository at this point in the history
Add pullImageSecrets for authenticated registries
  • Loading branch information
rasswanth-s authored Feb 2, 2024
2 parents 60f7184 + 1bf9644 commit 9f1afaf
Show file tree
Hide file tree
Showing 8 changed files with 166 additions and 31 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ metadata:
app.kubernetes.io/managed-by: Helm
rules:
- apiGroups: [""]
resources: ["pods", "configmaps"]
resources: ["pods", "configmaps", "secrets"]
verbs: ["create", "get", "list", "watch", "update", "patch", "delete"]
- apiGroups: [""]
resources: ["pods/log"]
Expand Down
92 changes: 89 additions & 3 deletions packages/syft/src/syft/custom_worker/runner_k8s.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
# stdlib
import base64
import copy
import json
import os
from time import sleep
from typing import List
Expand All @@ -10,6 +12,7 @@
import kr8s
from kr8s.objects import APIObject
from kr8s.objects import Pod
from kr8s.objects import Secret
from kr8s.objects import StatefulSet

# relative
Expand All @@ -27,16 +30,35 @@ def create_pool(
tag: str,
replicas: int = 1,
env_vars: Optional[dict] = None,
reg_username: Optional[str] = None,
reg_password: Optional[str] = None,
reg_url: Optional[str] = None,
**kwargs,
) -> StatefulSet:
# create pull secret if registry credentials are passed
pull_secret = None
if reg_username and reg_password and reg_url:
pull_secret = self._create_image_pull_secret(
pool_name,
reg_username,
reg_password,
reg_url,
)

# create a stateful set deployment
deployment = self._create_stateful_set(
pool_name,
tag,
replicas,
env_vars,
pull_secret=pull_secret,
**kwargs,
)

# wait for replicas to be available and ready
self.wait(deployment, available_replicas=replicas)

# return
return deployment

def scale_pool(self, pool_name: str, replicas: int) -> Optional[StatefulSet]:
Expand All @@ -57,8 +79,11 @@ def delete_pool(self, pool_name: str) -> bool:
selector = {"app.kubernetes.io/component": pool_name}
for _set in self.client.get("statefulsets", label_selector=selector):
_set.delete()
return True
return False

for _secret in self.client.get("secrets", label_selector=selector):
_secret.delete()

return True

def delete_pod(self, pod_name: str) -> bool:
pods = self.client.get("pods", pod_name)
Expand Down Expand Up @@ -99,7 +124,7 @@ def wait(
self,
deployment: StatefulSet,
available_replicas: int,
timeout: int = 60,
timeout: int = 300,
) -> None:
# TODO: Report wait('jsonpath=') bug to kr8s
# Until then this is the substitute implementation
Expand Down Expand Up @@ -133,17 +158,50 @@ def _get_obj_from_list(self, objs: List[dict], name: str) -> dict:
if obj.name == name:
return obj

def _create_image_pull_secret(
self,
pool_name: str,
reg_username: str,
reg_password: str,
reg_url: str,
**kwargs,
):
_secret = Secret(
{
"metadata": {
"name": f"pull-secret-{pool_name}",
"labels": {
"app.kubernetes.io/name": KUBERNETES_NAMESPACE,
"app.kubernetes.io/component": pool_name,
"app.kubernetes.io/managed-by": "kr8s",
},
},
"type": "kubernetes.io/dockerconfigjson",
"data": {
".dockerconfigjson": self._create_dockerconfig_json(
reg_username,
reg_password,
reg_url,
)
},
}
)

return self._create_or_get(_secret)

def _create_stateful_set(
self,
pool_name: str,
tag: str,
replicas=1,
env_vars: Optional[dict] = None,
pull_secret: Optional[Secret] = None,
**kwargs,
) -> StatefulSet:
"""Create a stateful set for a pool"""

env_vars = env_vars or {}
pull_secret_obj = None

_pod = Pod.get(self._current_pod_name())

Expand All @@ -170,6 +228,13 @@ def _create_stateful_set(
for k, v in env_vars.items():
env_clone.append({"name": k, "value": v})

if pull_secret:
pull_secret_obj = [
{
"name": pull_secret.name,
}
]

stateful_set = StatefulSet(
{
"metadata": {
Expand Down Expand Up @@ -198,12 +263,14 @@ def _create_stateful_set(
"containers": [
{
"name": pool_name,
"imagePullPolicy": "IfNotPresent",
"image": tag,
"env": env_clone,
"volumeMounts": [creds_volume_mount],
}
],
"volumes": [creds_volume],
"imagePullSecrets": pull_secret_obj,
},
},
},
Expand All @@ -217,3 +284,22 @@ def _create_or_get(self, obj: APIObject) -> APIObject:
else:
obj.refresh()
return obj

def _create_dockerconfig_json(
self,
reg_username: str,
reg_password: str,
reg_url: str,
):
config = {
"auths": {
reg_url: {
"username": reg_username,
"password": reg_password,
"auth": base64.b64encode(
f"{reg_username}:{reg_password}".encode()
).decode(),
}
}
}
return base64.b64encode(json.dumps(config).encode()).decode()
13 changes: 11 additions & 2 deletions packages/syft/src/syft/service/request/request.py
Original file line number Diff line number Diff line change
Expand Up @@ -235,6 +235,8 @@ def _run(
push_result = worker_image_service.push(
service_context,
image=worker_image.id,
username=context.extra_kwargs.get("reg_username", None),
password=context.extra_kwargs.get("reg_password", None),
)

if isinstance(push_result, SyftError):
Expand Down Expand Up @@ -299,6 +301,8 @@ def _run(
name=self.pool_name,
image_uid=self.image_uid,
num_workers=self.num_workers,
reg_username=context.extra_kwargs.get("reg_username", None),
reg_password=context.extra_kwargs.get("reg_password", None),
)
if isinstance(result, SyftError):
return Err(result)
Expand Down Expand Up @@ -487,7 +491,12 @@ def status(self) -> RequestStatus:

return request_status

def approve(self, disable_warnings: bool = False, approve_nested: bool = False):
def approve(
self,
disable_warnings: bool = False,
approve_nested: bool = False,
**kwargs: dict,
):
api = APIRegistry.api_for(
self.node_uid,
self.syft_client_verify_key,
Expand Down Expand Up @@ -518,7 +527,7 @@ def approve(self, disable_warnings: bool = False, approve_nested: bool = False):
prompt_warning_message(message=message, confirm=True)

print(f"Approving request for domain {api.node_name}")
return api.services.request.apply(self.id)
return api.services.request.apply(self.id, **kwargs)

def deny(self, reason: str):
"""Denies the particular request.
Expand Down
7 changes: 6 additions & 1 deletion packages/syft/src/syft/service/request/request_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -233,11 +233,16 @@ def filter_all_info(
name="apply",
)
def apply(
self, context: AuthedServiceContext, uid: UID
self,
context: AuthedServiceContext,
uid: UID,
**kwargs: dict,
) -> Union[SyftSuccess, SyftError]:
request = self.stash.get_by_uid(context.credentials, uid)
if request.is_ok():
request = request.ok()

context.extra_kwargs = kwargs
result = request.apply(context=context)

filter_by_obj = context.node.get_service_method(
Expand Down
22 changes: 18 additions & 4 deletions packages/syft/src/syft/service/worker/image_registry.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,9 @@
# stdlib
from urllib.parse import urlparse

# third party
from pydantic import validator

# relative
from ...serde.serializable import serializable
from ...types.syft_object import SYFT_OBJECT_VERSION_1
Expand All @@ -18,13 +24,21 @@ class SyftImageRegistry(SyftObject):
id: UID
url: str

@validator("url")
def validate_url(cls, val: str):
if val.startswith("http") or "://" in val:
raise ValueError("Registry URL must be a valid RFC 3986 URI")
return val

@classmethod
def from_url(cls, full_str: str):
return cls(id=UID(), url=full_str)
if "://" not in full_str:
full_str = f"http://{full_str}"

parsed = urlparse(full_str)

@property
def tls_enabled(self) -> bool:
return self.url.startswith("https")
# netloc includes the host & port, so local dev should work as expected
return cls(id=UID(), url=parsed.netloc)

def __hash__(self) -> int:
return hash(self.url + str(self.tls_enabled))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,11 +39,15 @@ def add(
context: AuthedServiceContext,
url: str,
) -> Union[SyftSuccess, SyftError]:
registry = SyftImageRegistry.from_url(url)
try:
registry = SyftImageRegistry.from_url(url)
except Exception as e:
return SyftError(message=f"Failed to create registry. {e}")

res = self.stash.set(context.credentials, registry)

if res.is_err():
return SyftError(message=res.err())
return SyftError(message=f"Failed to create registry. {res.err()}")

return SyftSuccess(
message=f"Image Registry ID: {registry.id} created successfully"
Expand Down
43 changes: 28 additions & 15 deletions packages/syft/src/syft/service/worker/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -261,6 +261,10 @@ def create_kubernetes_pool(
replicas: int,
queue_port: int,
debug: bool,
reg_username: Optional[str] = None,
reg_password: Optional[str] = None,
reg_url: Optional[str] = None,
**kwargs,
):
pool = None
error = False
Expand All @@ -285,6 +289,9 @@ def create_kubernetes_pool(
"CREATE_PRODUCER": "False",
"INMEMORY_WORKERS": "False",
},
reg_username=reg_username,
reg_password=reg_password,
reg_url=reg_url,
)
except Exception as e:
error = True
Expand Down Expand Up @@ -321,22 +328,25 @@ def run_workers_in_kubernetes(
queue_port: int,
start_idx=0,
debug: bool = False,
username: Optional[str] = None,
password: Optional[str] = None,
registry_url: Optional[str] = None,
reg_username: Optional[str] = None,
reg_password: Optional[str] = None,
reg_url: Optional[str] = None,
**kwargs,
) -> Union[List[ContainerSpawnStatus], SyftError]:
spawn_status = []
runner = KubernetesRunner()

if start_idx == 0:
pool_pods = create_kubernetes_pool(
runner,
worker_image,
pool_name,
worker_count,
queue_port,
debug,
runner=runner,
worker_image=worker_image,
pool_name=pool_name,
replicas=worker_count,
queue_port=queue_port,
debug=debug,
reg_username=reg_username,
reg_password=reg_password,
reg_url=reg_url,
)
else:
pool_pods = scale_kubernetes_pool(runner, pool_name, worker_count)
Expand Down Expand Up @@ -412,9 +422,9 @@ def run_containers(
queue_port: int,
dev_mode: bool = False,
start_idx: int = 0,
username: Optional[str] = None,
password: Optional[str] = None,
registry_url: Optional[str] = None,
reg_username: Optional[str] = None,
reg_password: Optional[str] = None,
reg_url: Optional[str] = None,
) -> Union[List[ContainerSpawnStatus], SyftError]:
results = []

Expand All @@ -435,9 +445,9 @@ def run_containers(
pool_name=pool_name,
queue_port=queue_port,
debug=dev_mode,
username=username,
password=password,
registry_url=registry_url,
username=reg_username,
password=reg_password,
registry_url=reg_url,
)
results.append(spawn_result)
elif orchestration == WorkerOrchestrationType.KUBERNETES:
Expand All @@ -448,6 +458,9 @@ def run_containers(
queue_port=queue_port,
debug=dev_mode,
start_idx=start_idx,
reg_username=reg_username,
reg_password=reg_password,
reg_url=reg_url,
)

return results
Expand Down
Loading

0 comments on commit 9f1afaf

Please sign in to comment.