From eb86cd576aa34d1af03f69ace443268bbad111ef Mon Sep 17 00:00:00 2001 From: Stefan Nica Date: Fri, 20 Dec 2024 09:43:44 +0100 Subject: [PATCH 01/13] Fix the onboarding state to account for zenml login --- src/zenml/zen_server/auth.py | 23 +++++++++++++++++-- .../zen_server/routers/auth_endpoints.py | 3 ++- src/zenml/zen_stores/rest_zen_store.py | 5 ++++ 3 files changed, 28 insertions(+), 3 deletions(-) diff --git a/src/zenml/zen_server/auth.py b/src/zenml/zen_server/auth.py index 80290f091c..18a5b6774e 100644 --- a/src/zenml/zen_server/auth.py +++ b/src/zenml/zen_server/auth.py @@ -37,7 +37,12 @@ LOGIN, VERSION_1, ) -from zenml.enums import AuthScheme, ExecutionStatus, OAuthDeviceStatus +from zenml.enums import ( + AuthScheme, + ExecutionStatus, + OAuthDeviceStatus, + OnboardingStep, +) from zenml.exceptions import ( AuthorizationException, CredentialsNotValid, @@ -630,12 +635,15 @@ def authenticate_device(client_id: UUID, device_code: str) -> AuthContext: return AuthContext(user=device_model.user, device=device_model) -def authenticate_external_user(external_access_token: str) -> AuthContext: +def authenticate_external_user( + external_access_token: str, request: Request +) -> AuthContext: """Implement external authentication. Args: external_access_token: The access token used to authenticate the user to the external authenticator. + request: The request object. Returns: The authentication context reflecting the authenticated user. @@ -761,6 +769,17 @@ def authenticate_external_user(external_access_token: str) -> AuthContext: ) context.alias(user_id=external_user.id, previous_id=user.id) + # This is the best spot to update the onboarding state to mark the + # "zenml login" step as completed for ZenML Pro servers, because the + # user has just successfully logged in. However, we need to differentiate + # between web clients (i.e. the dashboard) and CLI clients (i.e. the + # zenml CLI). + user_agent = request.headers.get("User-Agent", "").lower() + if "zenml/" in user_agent: + store.update_onboarding_state( + completed_steps={OnboardingStep.DEVICE_VERIFIED} + ) + return AuthContext(user=user) diff --git a/src/zenml/zen_server/routers/auth_endpoints.py b/src/zenml/zen_server/routers/auth_endpoints.py index a1339c10bf..fb0eee4391 100644 --- a/src/zenml/zen_server/routers/auth_endpoints.py +++ b/src/zenml/zen_server/routers/auth_endpoints.py @@ -287,7 +287,8 @@ def token( return OAuthRedirectResponse(authorization_url=authorization_url) auth_context = authenticate_external_user( - external_access_token=external_access_token + external_access_token=external_access_token, + request=request, ) else: diff --git a/src/zenml/zen_stores/rest_zen_store.py b/src/zenml/zen_stores/rest_zen_store.py index e3e29759b2..a7f013904a 100644 --- a/src/zenml/zen_stores/rest_zen_store.py +++ b/src/zenml/zen_stores/rest_zen_store.py @@ -4218,6 +4218,11 @@ def session(self) -> requests.Session: self._session.mount("https://", HTTPAdapter(max_retries=retries)) self._session.mount("http://", HTTPAdapter(max_retries=retries)) self._session.verify = self.config.verify_ssl + # Use a custom user agent to identify the ZenML client in the server + # logs. + self._session.headers.update( + {"User-Agent": "zenml/" + zenml.__version__} + ) # Note that we return an unauthenticated session here. An API token # is only fetched and set in the authorization header when and if it is From 479dca830e4ccbb4d2232240a85554f449f1f357 Mon Sep 17 00:00:00 2001 From: Stefan Nica Date: Wed, 11 Dec 2024 18:21:09 +0100 Subject: [PATCH 02/13] Implement cross-domain external authentication with CSRF protection --- pyproject.toml | 4 + src/zenml/models/v2/misc/auth_models.py | 1 + src/zenml/zen_server/auth.py | 95 +++++++++++++++++-- src/zenml/zen_server/csrf.py | 91 ++++++++++++++++++ src/zenml/zen_server/jwt.py | 17 +++- .../zen_server/routers/auth_endpoints.py | 20 +++- src/zenml/zen_server/utils.py | 45 +++++++++ 7 files changed, 263 insertions(+), 10 deletions(-) create mode 100644 src/zenml/zen_server/csrf.py diff --git a/pyproject.toml b/pyproject.toml index a71b71979a..288ca68cf9 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -73,6 +73,8 @@ orjson = { version = "~3.10.0", optional = true } Jinja2 = { version = "*", optional = true } ipinfo = { version = ">=4.4.3", optional = true } secure = { version = "~0.3.0", optional = true } +tldextract = { version = "~5.1.0", optional = true } +itsdangerous = { version = "~2.2.0", optional = true } # Optional dependencies for project templates copier = { version = ">=8.1.0", optional = true } @@ -189,6 +191,8 @@ server = [ "Jinja2", "ipinfo", "secure", + "tldextract", + "itsdangerous", ] templates = ["copier", "jinja2-time", "ruff", "pyyaml-include"] terraform = ["python-terraform"] diff --git a/src/zenml/models/v2/misc/auth_models.py b/src/zenml/models/v2/misc/auth_models.py index 2590ef1a2b..b833811773 100644 --- a/src/zenml/models/v2/misc/auth_models.py +++ b/src/zenml/models/v2/misc/auth_models.py @@ -119,6 +119,7 @@ class OAuthTokenResponse(BaseModel): token_type: str expires_in: Optional[int] = None refresh_token: Optional[str] = None + csrf_token: Optional[str] = None scope: Optional[str] = None cookie_name: Optional[str] = None device_id: Optional[UUID] = None diff --git a/src/zenml/zen_server/auth.py b/src/zenml/zen_server/auth.py index 18a5b6774e..e02ac10dca 100644 --- a/src/zenml/zen_server/auth.py +++ b/src/zenml/zen_server/auth.py @@ -16,8 +16,8 @@ from contextvars import ContextVar from datetime import datetime, timedelta from typing import Callable, Optional, Union -from urllib.parse import urlencode -from uuid import UUID +from urllib.parse import urlencode, urlparse +from uuid import UUID, uuid4 import requests from fastapi import Depends, Response @@ -63,9 +63,14 @@ UserUpdate, ) from zenml.zen_server.cache import cache_result +from zenml.zen_server.csrf import CSRFToken from zenml.zen_server.exceptions import http_exception_from_error from zenml.zen_server.jwt import JWTToken -from zenml.zen_server.utils import server_config, zen_store +from zenml.zen_server.utils import ( + is_same_or_subdomain, + server_config, + zen_store, +) logger = get_logger(__name__) @@ -176,6 +181,7 @@ def authenticate_credentials( user_name_or_id: Optional[Union[str, UUID]] = None, password: Optional[str] = None, access_token: Optional[str] = None, + csrf_token: Optional[str] = None, activation_token: Optional[str] = None, ) -> AuthContext: """Verify if user authentication credentials are valid. @@ -194,6 +200,7 @@ def authenticate_credentials( user_name_or_id: The username or user ID. password: The password. access_token: The access token. + csrf_token: The CSRF token. activation_token: The activation token. Returns: @@ -255,6 +262,22 @@ def authenticate_credentials( logger.exception(error) raise CredentialsNotValid(error) + if decoded_token.session_id: + if not csrf_token: + error = "Authentication error: missing CSRF token" + logger.error(error) + raise CredentialsNotValid(error) + + decoded_csrf_token = CSRFToken.decode_token(csrf_token) + + if decoded_csrf_token.session_id != decoded_token.session_id: + error = ( + "Authentication error: CSRF token does not match the " + "access token" + ) + logger.error(error) + raise CredentialsNotValid(error) + try: user_model = zen_store().get_user( user_name_or_id=decoded_token.user_id, include_private=True @@ -820,6 +843,7 @@ def authenticate_api_key( def generate_access_token( user_id: UUID, response: Optional[Response] = None, + request: Optional[Request] = None, device: Optional[OAuthDeviceInternalResponse] = None, api_key: Optional[APIKeyInternalResponse] = None, expires_in: Optional[int] = None, @@ -831,7 +855,11 @@ def generate_access_token( Args: user_id: The ID of the user. - response: The FastAPI response object. + response: The FastAPI response object. If passed, the access + token will also be set as an HTTP only cookie in the response. + request: The FastAPI request object. Used to determine the request + origin and to decide whether to use cross-site security measures for + the access token cookie. device: The device used for authentication. api_key: The service account API key used for authentication. expires_in: The number of seconds until the token expires. If not set, @@ -869,6 +897,46 @@ def generate_access_token( ) expires_in = config.jwt_token_expire_minutes * 60 + # Figure out if this is a same-site request or a cross-site request + same_site = True + if response and request: + # Extract the origin domain from the request; use the referer as a + # fallback + origin_domain: Optional[str] = None + origin = request.headers.get("origin", request.headers.get("referer")) + if origin: + # If the request origin is known, we use it to determine whether + # this is a cross-site request and enable additional security + # measures. + origin_domain = urlparse(origin).netloc + + server_domain: Optional[str] = config.auth_cookie_domain + # If the server's cookie domain is not explicitly set in the + # server's configuration, we use other sources to determine it: + # + # 1. the server's root URL, if set in the server's configuration + # 2. the X-Forwarded-Host header, if set by the reverse proxy + # 3. the request URL, if all else fails + if not server_domain and config.server_url: + server_domain = urlparse(config.server_url).netloc + if not server_domain: + server_domain = request.headers.get( + "x-forwarded-host", request.url.netloc + ) + + # Same-site requests can come from the same domain or from a + # subdomain of the domain used to issue cookies. + if origin_domain and server_domain: + same_site = is_same_or_subdomain(origin_domain, server_domain) + + csrf_token: Optional[str] = None + session_id: Optional[UUID] = None + if not same_site: + # If responding to a cross-site login request, we need to generate and + # sign a CSRF token associated with the authentication session. + session_id = uuid4() + csrf_token = CSRFToken(session_id=session_id).encode() + access_token = JWTToken( user_id=user_id, device_id=device.id if device else None, @@ -876,15 +944,18 @@ def generate_access_token( schedule_id=schedule_id, pipeline_run_id=pipeline_run_id, step_run_id=step_run_id, + # Set the session ID if this is a cross-site request + session_id=session_id, ).encode(expires=expires) - if not device and response: + if response: # Also set the access token as an HTTP only cookie in the response response.set_cookie( key=config.get_auth_cookie_name(), value=access_token, httponly=True, - samesite="lax", + secure=not same_site, + samesite="lax" if same_site else "none", max_age=config.jwt_token_expire_minutes * 60 if config.jwt_token_expire_minutes else None, @@ -892,7 +963,10 @@ def generate_access_token( ) return OAuthTokenResponse( - access_token=access_token, expires_in=expires_in, token_type="bearer" + access_token=access_token, + expires_in=expires_in, + token_type="bearer", + csrf_token=csrf_token, ) @@ -953,19 +1027,24 @@ def oauth2_authentication( tokenUrl=server_config().root_url_path + API + VERSION_1 + LOGIN, ) ), + request: Request = Depends(), ) -> AuthContext: """Authenticates any request to the ZenML server with OAuth2 JWT tokens. Args: token: The JWT bearer token to be authenticated. + request: The FastAPI request object. Returns: The authentication context reflecting the authenticated user. # noqa: DAR401 """ + csrf_token = request.headers.get("X-CSRF-Token") try: - auth_context = authenticate_credentials(access_token=token) + auth_context = authenticate_credentials( + access_token=token, csrf_token=csrf_token + ) except CredentialsNotValid as e: # We want to be very explicit here and return a CredentialsNotValid # exception encoded as a 401 Unauthorized error encoded, so that the diff --git a/src/zenml/zen_server/csrf.py b/src/zenml/zen_server/csrf.py new file mode 100644 index 0000000000..9bf51cab6e --- /dev/null +++ b/src/zenml/zen_server/csrf.py @@ -0,0 +1,91 @@ +# Copyright (c) ZenML GmbH 2024. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at: +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +# or implied. See the License for the specific language governing +# permissions and limitations under the License. +"""CSRF token utilities module for ZenML server.""" + +from uuid import UUID + +from pydantic import BaseModel + +from zenml.exceptions import CredentialsNotValid +from zenml.logger import get_logger +from zenml.zen_server.utils import server_config + +logger = get_logger(__name__) + + +class CSRFToken(BaseModel): + """Pydantic object representing a CSRF token. + + Attributes: + session_id: The id of the authenticated session. + """ + + session_id: UUID + + @classmethod + def decode_token( + cls, + token: str, + ) -> "CSRFToken": + """Decodes a CSRF token. + + Decodes a CSRF access token and returns a `CSRFToken` object with the + information retrieved from its contents. + + Args: + token: The encoded CSRF token. + + Returns: + The decoded CSRF token. + + Raises: + CredentialsNotValid: If the token is invalid. + """ + from itsdangerous import BadData, BadSignature, URLSafeSerializer + + config = server_config() + + serializer = URLSafeSerializer(config.jwt_secret_key) + try: + # Decode and verify the token + data = serializer.loads(token) + except BadSignature as e: + raise CredentialsNotValid( + "Invalid CSRF token: signature mismatch" + ) from e + except BadData as e: + raise CredentialsNotValid("Invalid CSRF token") from e + + try: + return CSRFToken(session_id=UUID(data)) + except ValueError as e: + raise CredentialsNotValid( + "Invalid CSRF token: the session ID is not a valid UUID" + ) from e + + def encode(self) -> str: + """Creates a CSRF token. + + Encodes, signs and returns a CSRF access token. + + Returns: + The generated CSRF token. + """ + from itsdangerous import URLSafeSerializer + + config = server_config() + + serializer = URLSafeSerializer(config.jwt_secret_key) + token = serializer.dumps(str(self.session_id)) + return token diff --git a/src/zenml/zen_server/jwt.py b/src/zenml/zen_server/jwt.py index eafa61a784..d1150ac028 100644 --- a/src/zenml/zen_server/jwt.py +++ b/src/zenml/zen_server/jwt.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express # or implied. See the License for the specific language governing # permissions and limitations under the License. -"""Authentication module for ZenML server.""" +"""JWT utilities module for ZenML server.""" from datetime import datetime, timedelta from typing import ( @@ -45,6 +45,7 @@ class JWTToken(BaseModel): issued. step_run_id: The id of the step run for which the token was issued. + session_id: The id of the authenticated session (used for CSRF). claims: The original token claims. """ @@ -54,6 +55,7 @@ class JWTToken(BaseModel): schedule_id: Optional[UUID] = None pipeline_run_id: Optional[UUID] = None step_run_id: Optional[UUID] = None + session_id: Optional[UUID] = None claims: Dict[str, Any] = {} @classmethod @@ -156,6 +158,16 @@ def decode_token( "UUID" ) + session_id: Optional[UUID] = None + if "session_id" in claims: + try: + session_id = UUID(claims.pop("session_id")) + except ValueError: + raise CredentialsNotValid( + "Invalid JWT token: the session_id claim is not a valid " + "UUID" + ) + return JWTToken( user_id=user_id, device_id=device_id, @@ -163,6 +175,7 @@ def decode_token( schedule_id=schedule_id, pipeline_run_id=pipeline_run_id, step_run_id=step_run_id, + session_id=session_id, claims=claims, ) @@ -201,6 +214,8 @@ def encode(self, expires: Optional[datetime] = None) -> str: claims["pipeline_run_id"] = str(self.pipeline_run_id) if self.step_run_id: claims["step_run_id"] = str(self.step_run_id) + if self.session_id: + claims["session_id"] = str(self.session_id) return jwt.encode( claims, diff --git a/src/zenml/zen_server/routers/auth_endpoints.py b/src/zenml/zen_server/routers/auth_endpoints.py index fb0eee4391..1754eccef1 100644 --- a/src/zenml/zen_server/routers/auth_endpoints.py +++ b/src/zenml/zen_server/routers/auth_endpoints.py @@ -237,6 +237,8 @@ def token( ValueError: If the grant type is invalid. """ config = server_config() + cookie_response: Optional[Response] = response + if auth_form_data.grant_type == OAuthGrantTypes.OAUTH_PASSWORD: auth_context = authenticate_credentials( user_name_or_id=auth_form_data.username, @@ -248,10 +250,17 @@ def token( client_id=auth_form_data.client_id, device_code=auth_form_data.device_code, ) + # API tokens for authorized device are only meant for non-web clients + # and should not be stored as cookies + cookie_response = None + elif auth_form_data.grant_type == OAuthGrantTypes.ZENML_API_KEY: auth_context = authenticate_api_key( api_key=auth_form_data.api_key, ) + # API tokens for API keys are only meant for non-web clients + # and should not be stored as cookies + cookie_response = None elif auth_form_data.grant_type == OAuthGrantTypes.ZENML_EXTERNAL: assert config.external_cookie_name is not None @@ -291,13 +300,18 @@ def token( request=request, ) + # TODO: no easy way to detect which type of client issued this request + # (web or non-web) in order to decide whether to store the access token + # as a cookie in the response or not. For now, we always assume a web + # client. else: # Shouldn't happen, because we verify all grants in the form data raise ValueError("Invalid grant type.") return generate_access_token( user_id=auth_context.user.id, - response=response, + response=cookie_response, + request=request, device=auth_context.device, api_key=auth_context.api_key, ) @@ -521,6 +535,8 @@ def api_token( return generate_access_token( user_id=token.user_id, expires_in=expires_in, + # Don't include the access token as a cookie in the response + response=None, ).access_token verify_permission( @@ -621,6 +637,8 @@ def api_token( schedule_id=schedule_id, pipeline_run_id=pipeline_run_id, step_run_id=step_run_id, + # Don't include the access token as a cookie in the response + response=None, # Never expire the token expires_in=0, ).access_token diff --git a/src/zenml/zen_server/utils.py b/src/zenml/zen_server/utils.py index 86414385ff..715fe78c13 100644 --- a/src/zenml/zen_server/utils.py +++ b/src/zenml/zen_server/utils.py @@ -571,3 +571,48 @@ def is_user_request(request: "Request") -> bool: # If none of the above conditions are met, consider it a user request return True + + +def is_same_or_subdomain(source_domain: str, target_domain: str) -> bool: + """Check if the source domain is the same or a subdomain of the target domain. + + Examples: + is_same_or_subdomain("example.com", "example.com") -> True + is_same_or_subdomain("alpha.example.com", "example.com") -> True + is_same_or_subdomain("alpha.example.com", ".example.com") -> True + is_same_or_subdomain("example.com", "alpha.example.com") -> False + is_same_or_subdomain("alpha.beta.example.com", "beta.example.com") -> True + is_same_or_subdomain("alpha.beta.example.com", "alpha.example.com") -> False + is_same_or_subdomain("alphabeta.gamma.example", "beta.gamma.example") -> False + + Args: + source_domain: The source domain to check. + target_domain: The target domain to compare against. + + Returns: + True if the source domain is the same or a subdomain of the target + domain, False otherwise. + """ + import tldextract + + # Extract the registered domain and suffix for both + src_parts = tldextract.extract(source_domain) + tgt_parts = tldextract.extract(target_domain) + + if src_parts == tgt_parts: + return True # Same domain + + # Reconstruct the base domains (e.g., example.com) + src_base_domain = f"{src_parts.domain}.{src_parts.suffix}" + tgt_base_domain = f"{tgt_parts.domain}.{tgt_parts.suffix}" + + if src_base_domain != tgt_base_domain: + return False # Different base domains + + if tgt_parts.subdomain == "": + return True # Subdomain + + if src_parts.subdomain.endswith(f".{tgt_parts.subdomain.lstrip('.')}"): + return True # Subdomain of subdomain + + return False From 5a99eff5096868931e0a8565f24acf472b2d26cd Mon Sep 17 00:00:00 2001 From: Stefan Nica Date: Fri, 13 Dec 2024 20:38:10 +0100 Subject: [PATCH 03/13] Remove dependence on external cookie for external auth --- src/zenml/config/server_config.py | 13 --------- src/zenml/login/credentials.py | 1 - src/zenml/login/credentials_store.py | 1 - src/zenml/models/v2/misc/auth_models.py | 1 - src/zenml/zen_server/cloud_utils.py | 12 +++++--- .../deploy/helm/templates/_environment.tpl | 3 -- src/zenml/zen_server/deploy/helm/values.yaml | 5 ---- .../zen_server/routers/auth_endpoints.py | 29 ++++++------------- 8 files changed, 17 insertions(+), 48 deletions(-) diff --git a/src/zenml/config/server_config.py b/src/zenml/config/server_config.py index e70a481d76..a9ad983645 100644 --- a/src/zenml/config/server_config.py +++ b/src/zenml/config/server_config.py @@ -127,10 +127,6 @@ class ServerConfiguration(BaseModel): to use with the `EXTERNAL` authentication scheme. external_user_info_url: The user info URL of an external authenticator service to use with the `EXTERNAL` authentication scheme. - external_cookie_name: The name of the http-only cookie used to store the - bearer token used to authenticate with the external authenticator - service. Must be specified if the `EXTERNAL` authentication scheme - is used. external_server_id: The ID of the ZenML server to use with the `EXTERNAL` authentication scheme. If not specified, the regular ZenML server ID is used. @@ -276,7 +272,6 @@ class ServerConfiguration(BaseModel): external_login_url: Optional[str] = None external_user_info_url: Optional[str] = None - external_cookie_name: Optional[str] = None external_server_id: Optional[UUID] = None rbac_implementation_source: Optional[str] = None @@ -370,14 +365,6 @@ def _validate_config(cls, data: Dict[str, Any]) -> Dict[str, Any]: "authentication scheme." ) - # If the authentication scheme is set to `EXTERNAL`, the - # external cookie name must be specified. - if not data.get("external_cookie_name"): - raise ValueError( - "The external cookie name must be specified when " - "using the EXTERNAL authentication scheme." - ) - if cors_allow_origins := data.get("cors_allow_origins"): origins = cors_allow_origins.split(",") data["cors_allow_origins"] = origins diff --git a/src/zenml/login/credentials.py b/src/zenml/login/credentials.py index 9da8e0b01c..529c883743 100644 --- a/src/zenml/login/credentials.py +++ b/src/zenml/login/credentials.py @@ -44,7 +44,6 @@ class APIToken(BaseModel): expires_in: Optional[int] = None expires_at: Optional[datetime] = None leeway: Optional[int] = None - cookie_name: Optional[str] = None device_id: Optional[UUID] = None device_metadata: Optional[Dict[str, Any]] = None diff --git a/src/zenml/login/credentials_store.py b/src/zenml/login/credentials_store.py index 34bfe440db..2287dae498 100644 --- a/src/zenml/login/credentials_store.py +++ b/src/zenml/login/credentials_store.py @@ -468,7 +468,6 @@ def set_token( expires_in=token_response.expires_in, expires_at=expires_at, leeway=leeway, - cookie_name=token_response.cookie_name, device_id=token_response.device_id, device_metadata=token_response.device_metadata, ) diff --git a/src/zenml/models/v2/misc/auth_models.py b/src/zenml/models/v2/misc/auth_models.py index b833811773..ab6f3a7876 100644 --- a/src/zenml/models/v2/misc/auth_models.py +++ b/src/zenml/models/v2/misc/auth_models.py @@ -121,7 +121,6 @@ class OAuthTokenResponse(BaseModel): refresh_token: Optional[str] = None csrf_token: Optional[str] = None scope: Optional[str] = None - cookie_name: Optional[str] = None device_id: Optional[UUID] = None device_metadata: Optional[Dict[str, Any]] = None diff --git a/src/zenml/zen_server/cloud_utils.py b/src/zenml/zen_server/cloud_utils.py index 083e690a35..e7420dbcdc 100644 --- a/src/zenml/zen_server/cloud_utils.py +++ b/src/zenml/zen_server/cloud_utils.py @@ -194,7 +194,7 @@ def _clear_session(self) -> None: self._token_expires_at = None def _fetch_auth_token(self) -> str: - """Fetch an auth token for the Cloud API from auth0. + """Fetch an auth token from the Cloud API. Raises: RuntimeError: If the auth token can't be fetched. @@ -210,7 +210,7 @@ def _fetch_auth_token(self) -> str: ): return self._token - # Get an auth token from auth0 + # Get an auth token from the Cloud API login_url = f"{self._config.api_url}/auth/login" headers = {"content-type": "application/x-www-form-urlencoded"} payload = { @@ -225,7 +225,9 @@ def _fetch_auth_token(self) -> str: ) response.raise_for_status() except Exception as e: - raise RuntimeError(f"Error fetching auth token from auth0: {e}") + raise RuntimeError( + f"Error fetching auth token from the Cloud API: {e}" + ) json_response = response.json() access_token = json_response.get("access_token", "") @@ -237,7 +239,9 @@ def _fetch_auth_token(self) -> str: or not expires_in or not isinstance(expires_in, int) ): - raise RuntimeError("Could not fetch auth token from auth0.") + raise RuntimeError( + "Could not fetch auth token from the Cloud API." + ) self._token = access_token self._token_expires_at = datetime.now(timezone.utc) + timedelta( diff --git a/src/zenml/zen_server/deploy/helm/templates/_environment.tpl b/src/zenml/zen_server/deploy/helm/templates/_environment.tpl index 2f64d5881c..8d73f91e42 100644 --- a/src/zenml/zen_server/deploy/helm/templates/_environment.tpl +++ b/src/zenml/zen_server/deploy/helm/templates/_environment.tpl @@ -190,9 +190,6 @@ external_login_url: {{ .ZenML.auth.externalLoginURL | quote }} {{- if .ZenML.auth.externalUserInfoURL }} external_user_info_url: {{ .ZenML.auth.externalUserInfoURL | quote }} {{- end }} -{{- if .ZenML.auth.externalCookieName }} -external_cookie_name: {{ .ZenML.auth.externalCookieName | quote }} -{{- end }} {{- if .ZenML.auth.externalServerID }} external_server_id: {{ .ZenML.auth.externalServerID | quote }} {{- end }} diff --git a/src/zenml/zen_server/deploy/helm/values.yaml b/src/zenml/zen_server/deploy/helm/values.yaml index 5e742c1cb0..b11e6b46e6 100644 --- a/src/zenml/zen_server/deploy/helm/values.yaml +++ b/src/zenml/zen_server/deploy/helm/values.yaml @@ -159,11 +159,6 @@ zenml: # is set to `EXTERNAL`. externalUserInfoURL: - # The name of the http-only cookie used to store the bearer token used to - # authenticate with the external authenticator service. Only relevant if - # `zenml.auth.authType` is set to `EXTERNAL`. - externalCookieName: - # The UUID of the ZenML server to use with the `EXTERNAL` authentication # scheme. If not specified, the regular ZenML server ID (deployment ID) is # used. diff --git a/src/zenml/zen_server/routers/auth_endpoints.py b/src/zenml/zen_server/routers/auth_endpoints.py index 1754eccef1..63acc3b2a9 100644 --- a/src/zenml/zen_server/routers/auth_endpoints.py +++ b/src/zenml/zen_server/routers/auth_endpoints.py @@ -263,32 +263,21 @@ def token( cookie_response = None elif auth_form_data.grant_type == OAuthGrantTypes.ZENML_EXTERNAL: - assert config.external_cookie_name is not None assert config.external_login_url is not None authorization_url = config.external_login_url - # First, try to get the external access token from the external cookie - external_access_token = request.cookies.get( - config.external_cookie_name - ) - if not external_access_token: - # Next, try to get the external access token from the authorization - # header - authorization_header = request.headers.get("Authorization") - if authorization_header: - scheme, _, token = authorization_header.partition(" ") - if token and scheme.lower() == "bearer": - external_access_token = token - logger.info( - "External access token found in authorization header." - ) - else: - logger.info("External access token found in cookie.") + # Try to get the external session token or authorization token from the + # authorization header + authorization_header = request.headers.get("Authorization") + if authorization_header: + scheme, _, token = authorization_header.partition(" ") + if token and scheme.lower() == "bearer": + external_access_token = token - if not external_access_token: + if not authorization_header: logger.info( - "External access token not found. Redirecting to " + "External session or authorization token not found. Redirecting to " "external authenticator." ) From b9096af4fd8bbeba67f3ab5ba9f3bcfed93420a5 Mon Sep 17 00:00:00 2001 From: Stefan Nica Date: Mon, 16 Dec 2024 19:32:49 +0100 Subject: [PATCH 04/13] Allow connecting to external servers enrolled as ZenML Pro tenants --- src/zenml/cli/login.py | 18 ++++++------ src/zenml/models/v2/misc/server_models.py | 8 ++++++ src/zenml/zen_server/auth.py | 2 +- src/zenml/zen_stores/rest_zen_store.py | 34 +++++++++++------------ 4 files changed, 36 insertions(+), 26 deletions(-) diff --git a/src/zenml/cli/login.py b/src/zenml/cli/login.py index 486be5baea..4cddc93bd3 100644 --- a/src/zenml/cli/login.py +++ b/src/zenml/cli/login.py @@ -145,6 +145,7 @@ def connect_to_server( api_key: Optional[str] = None, verify_ssl: Union[str, bool] = True, refresh: bool = False, + pro_server: bool = False, ) -> None: """Connect the client to a ZenML server or a SQL database. @@ -154,6 +155,7 @@ def connect_to_server( verify_ssl: Whether to verify the server's TLS certificate. If a string is passed, it is interpreted as the path to a CA bundle file. refresh: Whether to force a new login flow with the ZenML server. + pro_server: Whether the server is a ZenML Pro server. """ from zenml.login.credentials_store import get_credentials_store from zenml.zen_stores.base_zen_store import BaseZenStore @@ -170,7 +172,12 @@ def connect_to_server( f"Authenticating to ZenML server '{url}' using an API key..." ) credentials_store.set_api_key(url, api_key) - elif not is_zenml_pro_server_url(url): + elif pro_server: + # We don't have to do anything here assuming the user has already + # logged in to the ZenML Pro server using the ZenML Pro web login + # flow. + cli_utils.declare(f"Authenticating to ZenML server '{url}'...") + else: if refresh or not credentials_store.has_valid_authentication(url): cli_utils.declare( f"Authenticating to ZenML server '{url}' using the web " @@ -179,11 +186,6 @@ def connect_to_server( web_login(url=url, verify_ssl=verify_ssl) else: cli_utils.declare(f"Connecting to ZenML server '{url}'...") - else: - # We don't have to do anything here assuming the user has already - # logged in to the ZenML Pro server using the ZenML Pro web login - # flow. - cli_utils.declare(f"Authenticating to ZenML server '{url}'...") rest_store_config = RestZenStoreConfiguration( url=url, @@ -277,7 +279,7 @@ def connect_to_pro_server( # server to connect to. if api_key: if server_url: - connect_to_server(server_url, api_key=api_key) + connect_to_server(server_url, api_key=api_key, pro_server=True) return else: raise ValueError( @@ -405,7 +407,7 @@ def connect_to_pro_server( f"Connecting to ZenML Pro server: {server.name} [{str(server.id)}] " ) - connect_to_server(server.url, api_key=api_key) + connect_to_server(server.url, api_key=api_key, pro_server=True) # Update the stored server info with more accurate data taken from the # ZenML Pro tenant object. diff --git a/src/zenml/models/v2/misc/server_models.py b/src/zenml/models/v2/misc/server_models.py index be548189c3..d7e619020a 100644 --- a/src/zenml/models/v2/misc/server_models.py +++ b/src/zenml/models/v2/misc/server_models.py @@ -119,6 +119,14 @@ def is_local(self) -> bool: # server ID is the same as the local client (user) ID. return self.id == GlobalConfiguration().user_id + def is_pro_server(self) -> bool: + """Return whether the server is a ZenML Pro server. + + Returns: + True if the server is a ZenML Pro server, False otherwise. + """ + return self.deployment_type == ServerDeploymentType.CLOUD + class ServerLoadInfo(BaseModel): """Domain model for ZenML server load information.""" diff --git a/src/zenml/zen_server/auth.py b/src/zenml/zen_server/auth.py index e02ac10dca..4b5a74e943 100644 --- a/src/zenml/zen_server/auth.py +++ b/src/zenml/zen_server/auth.py @@ -1022,12 +1022,12 @@ async def __call__(self, request: Request) -> Optional[str]: def oauth2_authentication( + request: Request, token: str = Depends( CookieOAuth2TokenBearer( tokenUrl=server_config().root_url_path + API + VERSION_1 + LOGIN, ) ), - request: Request = Depends(), ) -> AuthContext: """Authenticates any request to the ZenML server with OAuth2 JWT tokens. diff --git a/src/zenml/zen_stores/rest_zen_store.py b/src/zenml/zen_stores/rest_zen_store.py index a7f013904a..74f6961e9e 100644 --- a/src/zenml/zen_stores/rest_zen_store.py +++ b/src/zenml/zen_stores/rest_zen_store.py @@ -450,6 +450,7 @@ class RestZenStore(BaseZenStore): CONFIG_TYPE: ClassVar[Type[StoreConfiguration]] = RestZenStoreConfiguration _api_token: Optional[APIToken] = None _session: Optional[requests.Session] = None + _server_info: Optional[ServerModel] = None # ==================================== # ZenML Store interface implementation @@ -469,7 +470,7 @@ def _initialize(self) -> None: """ try: client_version = zenml.__version__ - server_version = self.get_store_info().version + server_version = self.server_info.version # Handle cases where the ZenML server is not available except ConnectionError as e: @@ -522,6 +523,17 @@ def _initialize(self) -> None: ENV_ZENML_DISABLE_CLIENT_SERVER_MISMATCH_WARNING, ) + @property + def server_info(self) -> ServerModel: + """Get cached information about the server. + + Returns: + Cached information about the server. + """ + if self._server_info is None: + return self.get_store_info() + return self._server_info + def get_store_info(self) -> ServerModel: """Get information about the server. @@ -529,7 +541,8 @@ def get_store_info(self) -> ServerModel: Information about the server. """ body = self.get(INFO) - return ServerModel.model_validate(body) + self._server_info = ServerModel.model_validate(body) + return self._server_info def get_deployment_id(self) -> UUID: """Get the ID of the deployment. @@ -537,7 +550,7 @@ def get_deployment_id(self) -> UUID: Returns: The ID of the deployment. """ - return self.get_store_info().id + return self.server_info.id # -------------------- Server Settings -------------------- @@ -4028,19 +4041,6 @@ def get_or_generate_api_token(self) -> str: token = credentials.api_token if credentials else None if credentials and token and not token.expired: self._api_token = token - - # Populate the server info in the credentials store if it is - # not already present - if not credentials.server_id: - try: - server_info = self.get_store_info() - except Exception as e: - logger.warning(f"Failed to get server info: {e}.") - else: - credentials_store.update_server_info( - self.url, server_info - ) - return self._api_token.access_token # Token is expired or not found in the cache. Time to get a new one. @@ -4084,7 +4084,7 @@ def get_or_generate_api_token(self) -> str: "username": username, "password": password, } - elif is_zenml_pro_server_url(self.url): + elif self.server_info.is_pro_server(): # ZenML Pro tenants use a proprietary authorization grant # where the ZenML Pro API session token is exchanged for a # regular ZenML server access token. From 33af3cefcb01eaa4e840b5b512b592df9e7313e9 Mon Sep 17 00:00:00 2001 From: Stefan Nica Date: Fri, 20 Dec 2024 09:32:39 +0100 Subject: [PATCH 05/13] Updated support for external ZenML Pro server enrollment --- examples/e2e/pipelines/training.py | 13 +- examples/quickstart/pipelines/training.py | 8 +- src/zenml/config/server_config.py | 71 ++++++++++- src/zenml/constants.py | 1 + .../orchestrators/sagemaker_orchestrator.py | 8 +- .../wandb_experiment_tracker_flavor.py | 6 +- src/zenml/zen_server/auth.py | 10 ++ src/zenml/zen_server/cloud_utils.py | 61 ++------- .../deploy/helm/templates/_environment.tpl | 116 ++++++++++++++---- .../deploy/helm/templates/server-secret.yaml | 3 + src/zenml/zen_server/deploy/helm/values.yaml | 68 +++++++++- .../feature_gate/feature_gate_interface.py | 2 +- .../zen_server/rbac/rbac_sql_zen_store.py | 7 +- .../zen_server/routers/auth_endpoints.py | 13 ++ .../zen_server/template_execution/utils.py | 7 +- src/zenml/zen_server/utils.py | 19 +++ .../1cb6477f72d6_move_artifact_save_type.py | 30 +++-- ...557b2871693_update_step_run_input_types.py | 12 +- .../cc269488e5a9_separate_run_metadata.py | 18 ++- 19 files changed, 355 insertions(+), 118 deletions(-) diff --git a/examples/e2e/pipelines/training.py b/examples/e2e/pipelines/training.py index ba9a2f7548..52b6ab189b 100644 --- a/examples/e2e/pipelines/training.py +++ b/examples/e2e/pipelines/training.py @@ -119,12 +119,13 @@ def e2e_use_case_training( target=target, ) ########## Promotion stage ########## - latest_metric, current_metric = ( - compute_performance_metrics_on_current_data( - dataset_tst=dataset_tst, - target_env=target_env, - after=["model_evaluator"], - ) + ( + latest_metric, + current_metric, + ) = compute_performance_metrics_on_current_data( + dataset_tst=dataset_tst, + target_env=target_env, + after=["model_evaluator"], ) promote_with_metric_compare( diff --git a/examples/quickstart/pipelines/training.py b/examples/quickstart/pipelines/training.py index 55439cc582..2f8e9ff915 100644 --- a/examples/quickstart/pipelines/training.py +++ b/examples/quickstart/pipelines/training.py @@ -47,9 +47,11 @@ def english_translation_pipeline( tokenized_dataset, tokenizer = tokenize_data( dataset=full_dataset, model_type=model_type ) - tokenized_train_dataset, tokenized_eval_dataset, tokenized_test_dataset = ( - split_dataset(tokenized_dataset) - ) + ( + tokenized_train_dataset, + tokenized_eval_dataset, + tokenized_test_dataset, + ) = split_dataset(tokenized_dataset) model = train_model( tokenized_dataset=tokenized_train_dataset, model_type=model_type, diff --git a/src/zenml/config/server_config.py b/src/zenml/config/server_config.py index a9ad983645..5f82c501d0 100644 --- a/src/zenml/config/server_config.py +++ b/src/zenml/config/server_config.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express # or implied. See the License for the specific language governing # permissions and limitations under the License. -"""Functionality to support ZenML GlobalConfiguration.""" +"""Functionality to support ZenML Server Configuration.""" import json import os @@ -19,7 +19,14 @@ from typing import Any, Dict, List, Optional, Union from uuid import UUID -from pydantic import BaseModel, ConfigDict, Field, PositiveInt, model_validator +from pydantic import ( + BaseModel, + ConfigDict, + Field, + PositiveInt, + model_validator, + field_validator, +) from zenml.constants import ( DEFAULT_ZENML_JWT_TOKEN_ALGORITHM, @@ -44,6 +51,7 @@ DEFAULT_ZENML_SERVER_SECURE_HEADERS_XXP, DEFAULT_ZENML_SERVER_THREAD_POOL_SIZE, ENV_ZENML_SERVER_PREFIX, + ENV_ZENML_SERVER_PRO_PREFIX, ) from zenml.enums import AuthScheme from zenml.logger import get_logger @@ -539,6 +547,9 @@ def get_server_config(cls) -> "ServerConfiguration": for k, v in os.environ.items(): if v == "": continue + if k.startswith(ENV_ZENML_SERVER_PRO_PREFIX): + # Skip Pro configuration + continue if k.startswith(ENV_ZENML_SERVER_PREFIX): env_server_config[ k[len(ENV_ZENML_SERVER_PREFIX) :].lower() @@ -551,3 +562,59 @@ def get_server_config(cls) -> "ServerConfiguration": # permit downgrading extra="allow", ) + + +class ServerProConfiguration(BaseModel): + """ZenML Server Pro configuration attributes. + + All these attributes can be set through the environment with the + `ZENML_SERVER_PRO_`-Prefix. E.g. the value of the `ZENML_SERVER_PRO_API_URL` + environment variable will be extracted to api_url. + + Attributes: + api_url: The ZenML Pro API URL. + oauth2_client_secret: The ZenML Pro OAuth2 client secret used to + authenticate the ZenML server with the ZenML Pro API. + oauth2_audience: The OAuth2 audience. + """ + + api_url: str + oauth2_client_secret: str + oauth2_audience: str + + @field_validator("api_url") + @classmethod + def _strip_trailing_slashes_url(cls, url: str) -> str: + """Strip any trailing slashes on the API URL. + + Args: + url: The API URL. + + Returns: + The API URL with potential trailing slashes removed. + """ + return url.rstrip("/") + + @classmethod + def get_server_config(cls) -> "ServerProConfiguration": + """Get the server Pro configuration. + + Returns: + The server Pro configuration. + """ + env_server_config: Dict[str, Any] = {} + for k, v in os.environ.items(): + if v == "": + continue + if k.startswith(ENV_ZENML_SERVER_PRO_PREFIX): + env_server_config[ + k[len(ENV_ZENML_SERVER_PRO_PREFIX) :].lower() + ] = v + + return ServerProConfiguration(**env_server_config) + + model_config = ConfigDict( + # Allow extra attributes from configs of previous ZenML versions to + # permit downgrading + extra="allow", + ) diff --git a/src/zenml/constants.py b/src/zenml/constants.py index 183b1acce1..21b50425b8 100644 --- a/src/zenml/constants.py +++ b/src/zenml/constants.py @@ -176,6 +176,7 @@ def handle_int_env_var(var: str, default: int = 0) -> int: # ZenML Server environment variables ENV_ZENML_SERVER_PREFIX = "ZENML_SERVER_" +ENV_ZENML_SERVER_PRO_PREFIX = "ZENML_SERVER_PRO_" ENV_ZENML_SERVER_DEPLOYMENT_TYPE = f"{ENV_ZENML_SERVER_PREFIX}DEPLOYMENT_TYPE" ENV_ZENML_SERVER_AUTH_SCHEME = f"{ENV_ZENML_SERVER_PREFIX}AUTH_SCHEME" ENV_ZENML_SERVER_REPORTABLE_RESOURCES = ( diff --git a/src/zenml/integrations/aws/orchestrators/sagemaker_orchestrator.py b/src/zenml/integrations/aws/orchestrators/sagemaker_orchestrator.py index f832647a97..46f89424d9 100644 --- a/src/zenml/integrations/aws/orchestrators/sagemaker_orchestrator.py +++ b/src/zenml/integrations/aws/orchestrators/sagemaker_orchestrator.py @@ -632,9 +632,11 @@ def _compute_orchestrator_url( the URL to the dashboard view in SageMaker. """ try: - region_name, pipeline_name, execution_id = ( - dissect_pipeline_execution_arn(pipeline_execution.arn) - ) + ( + region_name, + pipeline_name, + execution_id, + ) = dissect_pipeline_execution_arn(pipeline_execution.arn) # Get the Sagemaker session session = pipeline_execution.sagemaker_session diff --git a/src/zenml/integrations/wandb/flavors/wandb_experiment_tracker_flavor.py b/src/zenml/integrations/wandb/flavors/wandb_experiment_tracker_flavor.py index 7a1a732170..96f366af79 100644 --- a/src/zenml/integrations/wandb/flavors/wandb_experiment_tracker_flavor.py +++ b/src/zenml/integrations/wandb/flavors/wandb_experiment_tracker_flavor.py @@ -23,7 +23,7 @@ cast, ) -from pydantic import field_validator, BaseModel +from pydantic import BaseModel, field_validator from zenml.config.base_settings import BaseSettings from zenml.experiment_trackers.base_experiment_tracker import ( @@ -69,8 +69,8 @@ def _convert_settings(cls, value: Any) -> Any: import wandb if isinstance(value, wandb.Settings): - # Depending on the wandb version, either `model_dump`, - # `make_static` or `to_dict` is available to convert the settings + # Depending on the wandb version, either `model_dump`, + # `make_static` or `to_dict` is available to convert the settings # to a dictionary if isinstance(value, BaseModel): return value.model_dump() diff --git a/src/zenml/zen_server/auth.py b/src/zenml/zen_server/auth.py index 4b5a74e943..3287bcaaa4 100644 --- a/src/zenml/zen_server/auth.py +++ b/src/zenml/zen_server/auth.py @@ -67,6 +67,7 @@ from zenml.zen_server.exceptions import http_exception_from_error from zenml.zen_server.jwt import JWTToken from zenml.zen_server.utils import ( + get_zenml_headers, is_same_or_subdomain, server_config, zen_store, @@ -307,6 +308,14 @@ def authenticate_credentials( device_model: Optional[OAuthDeviceInternalResponse] = None if decoded_token.device_id: + if server_config().auth_scheme in [ + AuthScheme.NO_AUTH, + AuthScheme.EXTERNAL, + ]: + error = "Authentication error: device authorization is not supported." + logger.error(error) + raise CredentialsNotValid(error) + # Access tokens that have been issued for a device are only valid # for that device, so we need to check if the device ID matches any # of the valid devices in the database. @@ -685,6 +694,7 @@ def authenticate_external_user( # Get the user information from the external authenticator user_info_url = config.external_user_info_url headers = {"Authorization": "Bearer " + external_access_token} + headers.update(get_zenml_headers()) query_params = dict(server_id=str(config.get_external_server_id())) try: diff --git a/src/zenml/zen_server/cloud_utils.py b/src/zenml/zen_server/cloud_utils.py index e7420dbcdc..6a52cc4877 100644 --- a/src/zenml/zen_server/cloud_utils.py +++ b/src/zenml/zen_server/cloud_utils.py @@ -1,71 +1,23 @@ """Utils concerning anything concerning the cloud control plane backend.""" -import os from datetime import datetime, timedelta, timezone from typing import Any, Dict, Optional import requests -from pydantic import BaseModel, ConfigDict, field_validator from requests.adapters import HTTPAdapter, Retry +from zenml.config.server_config import ServerProConfiguration from zenml.exceptions import SubscriptionUpgradeRequiredError -from zenml.zen_server.utils import server_config - -ZENML_CLOUD_RBAC_ENV_PREFIX = "ZENML_CLOUD_" +from zenml.zen_server.utils import get_zenml_headers, server_config _cloud_connection: Optional["ZenMLCloudConnection"] = None - -class ZenMLCloudConfiguration(BaseModel): - """ZenML Pro RBAC configuration.""" - - api_url: str - oauth2_client_id: str - oauth2_client_secret: str - oauth2_audience: str - - @field_validator("api_url") - @classmethod - def _strip_trailing_slashes_url(cls, url: str) -> str: - """Strip any trailing slashes on the API URL. - - Args: - url: The API URL. - - Returns: - The API URL with potential trailing slashes removed. - """ - return url.rstrip("/") - - @classmethod - def from_environment(cls) -> "ZenMLCloudConfiguration": - """Get the RBAC configuration from environment variables. - - Returns: - The RBAC configuration. - """ - env_config: Dict[str, Any] = {} - for k, v in os.environ.items(): - if v == "": - continue - if k.startswith(ZENML_CLOUD_RBAC_ENV_PREFIX): - env_config[k[len(ZENML_CLOUD_RBAC_ENV_PREFIX) :].lower()] = v - - return ZenMLCloudConfiguration(**env_config) - - model_config = ConfigDict( - # Allow extra attributes from configs of previous ZenML versions to - # permit downgrading - extra="allow" - ) - - class ZenMLCloudConnection: """Class to use for communication between server and control plane.""" def __init__(self) -> None: """Initialize the RBAC component.""" - self._config = ZenMLCloudConfiguration.from_environment() + self._config = ServerProConfiguration.get_server_config() self._session: Optional[requests.Session] = None self._token: Optional[str] = None self._token_expires_at: Optional[datetime] = None @@ -169,6 +121,8 @@ def session(self) -> requests.Session: self._session = requests.Session() token = self._fetch_auth_token() self._session.headers.update({"Authorization": "Bearer " + token}) + # Add the ZenML specific headers + self._session.headers.update(get_zenml_headers()) retries = Retry( total=5, backoff_factor=0.1, status_forcelist=[502, 504] @@ -213,8 +167,11 @@ def _fetch_auth_token(self) -> str: # Get an auth token from the Cloud API login_url = f"{self._config.api_url}/auth/login" headers = {"content-type": "application/x-www-form-urlencoded"} + # Add zenml specific headers to the request + headers.update(get_zenml_headers()) payload = { - "client_id": self._config.oauth2_client_id, + # The client ID is the external server ID + "client_id": str(server_config().get_external_server_id()), "client_secret": self._config.oauth2_client_secret, "audience": self._config.oauth2_audience, "grant_type": "client_credentials", diff --git a/src/zenml/zen_server/deploy/helm/templates/_environment.tpl b/src/zenml/zen_server/deploy/helm/templates/_environment.tpl index 8d73f91e42..f100f87cc0 100644 --- a/src/zenml/zen_server/deploy/helm/templates/_environment.tpl +++ b/src/zenml/zen_server/deploy/helm/templates/_environment.tpl @@ -140,8 +140,55 @@ Returns: A dictionary with the non-secret values configured for the ZenML server. */}} {{- define "zenml.serverConfigurationAttrs" -}} + +{{- if .ZenML.pro.enabled }} + +auth_scheme: "external" +deployment_type: cloud +cors_allow_origins: "{{ .ZenML.pro.dashboardURL }},{{ .ZenML.pro.serverURL }}" +external_login_url: "{{ .ZenML.pro.dashboardURL }}/api/auth/login" +external_user_info_url: "{{ .ZenML.pro.dashboardURL }}/users/authorize_server" +external_server_id: {{ .ZenML.pro.tenantID | quote }} +jwt_token_expire_minutes: "60" +rbac_implementation_source: "zenml.zen_server.rbac.zenml_cloud_rbac.ZenMLCloudRBAC" +feature_gate_implementation_source: "zenml.zen_server.feature_gate.zenml_cloud_feature_gate.ZenMLCloudFeatureGateInterface" +dashboard_url: "{{ .ZenML.pro.dashboardURL }}/organizations/{{ .ZenML.pro.organizationID }}/tenants/{{ .ZenML.pro.tenantID }}" +metadata: '{"account_id":"{{ .ZenML.pro.organizationID }}","organization_id": "{{ .ZenML.pro.organizationID }}","tenant_id":"{{ .ZenML.pro.tenantID }}"}' +reportable_resources: '["pipeline","pipeline_run","model"]' +pro_api_url: "{{ .ZenML.pro.apiURL }}" +pro_oauth2_audience: "{{ .ZenML.pro.apiURL }}" + +{{- else }} + auth_scheme: {{ .ZenML.authType | default .ZenML.auth.authType | quote }} deployment_type: {{ .ZenML.deploymentType | default "kubernetes" }} +{{- if .ZenML.auth.corsAllowOrigins }} +cors_allow_origins: {{ join "," .ZenML.auth.corsAllowOrigins | quote }} +{{- end }} +{{- if .ZenML.auth.externalLoginURL }} +external_login_url: {{ .ZenML.auth.externalLoginURL | quote }} +{{- end }} +{{- if .ZenML.auth.externalUserInfoURL }} +external_user_info_url: {{ .ZenML.auth.externalUserInfoURL | quote }} +{{- end }} +{{- if .ZenML.auth.externalServerID }} +external_server_id: {{ .ZenML.auth.externalServerID | quote }} +{{- end }} +{{- if .ZenML.auth.jwtTokenExpireMinutes }} +jwt_token_expire_minutes: {{ .ZenML.auth.jwtTokenExpireMinutes | quote }} +{{- end }} +{{- if .ZenML.auth.rbacImplementationSource }} +rbac_implementation_source: {{ .ZenML.auth.rbacImplementationSource | quote }} +{{- end }} +{{- if .ZenML.auth.featureGateImplementationSource }} +feature_gate_implementation_source: {{ .ZenML.auth.featureGateImplementationSource | quote }} +{{- end }} +{{- if .ZenML.dashboardURL }} +dashboard_url: {{ .ZenML.dashboardURL | quote }} +{{- end }} + +{{- end }} + {{- if .ZenML.threadPoolSize }} thread_pool_size: {{ .ZenML.threadPoolSize | quote }} {{- end }} @@ -157,18 +204,12 @@ jwt_token_audience: {{ .ZenML.auth.jwtTokenAudience | quote }} {{- if .ZenML.auth.jwtTokenLeewaySeconds }} jwt_token_leeway_seconds: {{ .ZenML.auth.jwtTokenLeewaySeconds | quote }} {{- end }} -{{- if .ZenML.auth.jwtTokenExpireMinutes }} -jwt_token_expire_minutes: {{ .ZenML.auth.jwtTokenExpireMinutes | quote }} -{{- end }} {{- if .ZenML.auth.authCookieName }} auth_cookie_name: {{ .ZenML.auth.authCookieName | quote }} {{- end }} {{- if .ZenML.auth.authCookieDomain }} auth_cookie_domain: {{ .ZenML.auth.authCookieDomain | quote }} {{- end }} -{{- if .ZenML.auth.corsAllowOrigins }} -cors_allow_origins: {{ join "," .ZenML.auth.corsAllowOrigins | quote }} -{{- end }} {{- if .ZenML.auth.maxFailedDeviceAuthAttempts }} max_failed_device_auth_attempts: {{ .ZenML.auth.maxFailedDeviceAuthAttempts | quote }} {{- end }} @@ -184,33 +225,45 @@ device_expiration_minutes: {{ .ZenML.auth.deviceExpirationMinutes | quote }} {{- if .ZenML.auth.trustedDeviceExpirationMinutes }} trusted_device_expiration_minutes: {{ .ZenML.auth.trustedDeviceExpirationMinutes | quote }} {{- end }} -{{- if .ZenML.auth.externalLoginURL }} -external_login_url: {{ .ZenML.auth.externalLoginURL | quote }} -{{- end }} -{{- if .ZenML.auth.externalUserInfoURL }} -external_user_info_url: {{ .ZenML.auth.externalUserInfoURL | quote }} -{{- end }} -{{- if .ZenML.auth.externalServerID }} -external_server_id: {{ .ZenML.auth.externalServerID | quote }} -{{- end }} {{- if .ZenML.rootUrlPath }} root_url_path: {{ .ZenML.rootUrlPath | quote }} {{- end }} {{- if .ZenML.serverURL }} server_url: {{ .ZenML.serverURL | quote }} {{- end }} -{{- if .ZenML.dashboardURL }} -dashboard_url: {{ .ZenML.dashboardURL | quote }} -{{- end }} -{{- if .ZenML.auth.rbacImplementationSource }} -rbac_implementation_source: {{ .ZenML.auth.rbacImplementationSource | quote }} -{{- end }} {{- range $key, $value := .ZenML.secure_headers }} secure_headers_{{ $key }}: {{ $value | quote }} {{- end }} {{- end }} +{{/* +ZenML server configuration options (secret values). + +This template constructs a dictionary that is similar to the python values that +can be configured in the zenml.config.server_config.ServerConfiguration +class. Only secret values are included in this dictionary. + +The dictionary is then converted into deployment environment variables by other +templates and inserted where it is needed. + +The input is taken from a .ZenML dict that is passed to the template and +contains the values configured in the values.yaml file for the ZenML server. + +Args: + .ZenML: A dictionary with the ZenML configuration values configured for the + ZenML server. +Returns: + A dictionary with the secret values configured for the ZenML server. +*/}} +{{- define "zenml.serverSecretConfigurationAttrs" -}} + +{{- if .ZenML.pro.enabled }} +pro_oauth2_client_secret: {{ .ZenML.pro.enrollmentKey | quote }} +{{- end }} +{{- end }} + + {{/* Server configuration environment variables (non-secret values). @@ -232,6 +285,27 @@ ZENML_SERVER_{{ $k | upper }}: {{ $v | quote }} {{- end }} +{{/* +Server configuration environment variables (secret values). + +Passes the .Values.zenml dict as input to the `zenml.serverSecretConfigurationAttrs` +template and converts the output into a dictionary of environment variables that +need to be configured for the server. + +Args: + .Values: The values.yaml file for the ZenML deployment. +Returns: + A dictionary with the secret environment variables that are configured for + the server (i.e. keys starting with `ZENML_SERVER_`). +*/}} +{{- define "zenml.serverSecretEnvVariables" -}} +{{ $zenml := dict "ZenML" .Values.zenml }} +{{- range $k, $v := include "zenml.serverSecretConfigurationAttrs" $zenml | fromYaml }} +ZENML_SERVER_{{ $k | upper }}: {{ $v | quote }} +{{- end }} +{{- end }} + + {{/* Secrets store configuration options (non-secret values). diff --git a/src/zenml/zen_server/deploy/helm/templates/server-secret.yaml b/src/zenml/zen_server/deploy/helm/templates/server-secret.yaml index 9ebf401b51..fcf2484177 100644 --- a/src/zenml/zen_server/deploy/helm/templates/server-secret.yaml +++ b/src/zenml/zen_server/deploy/helm/templates/server-secret.yaml @@ -13,6 +13,9 @@ data: {{- range $k, $v := include "zenml.storeSecretEnvVariables" . | fromYaml}} {{ $k }}: {{ $v | b64enc | quote }} {{- end }} + {{- range $k, $v := include "zenml.serverSecretConfigurationAttrs" . | fromYaml}} + {{ $k }}: {{ $v | b64enc | quote }} + {{- end }} {{- range $k, $v := include "zenml.secretsStoreSecretEnvVariables" . | fromYaml}} {{ $k }}: {{ $v | b64enc | quote }} {{- end }} diff --git a/src/zenml/zen_server/deploy/helm/values.yaml b/src/zenml/zen_server/deploy/helm/values.yaml index b11e6b46e6..bea454d7ad 100644 --- a/src/zenml/zen_server/deploy/helm/values.yaml +++ b/src/zenml/zen_server/deploy/helm/values.yaml @@ -27,8 +27,33 @@ zenml: # Overrides the image tag whose default is the chart appVersion. tag: + # ZenML Pro related options. + pro: + # Set `enabled` to true to enable ZenML Pro servers. If set, some of the + # configuration options in the `zenml` section will be overridden with + # values specific to ZenML Pro servers computed from the values set in the + # `pro` section. + enabled: false + + # The URL where the ZenML Pro server API is reachable. + apiURL: https://cloudapi.zenml.io + + # The URL where the ZenML Pro dashboard is reachable. + dashboardURL: https://cloud.zenml.io + + # The ID of the ZenML Pro tenant to use. + tenantID: + + # The ID of the ZenML Pro organization to use. + organizationID: + + # The enrollment key to use for the ZenML Pro tenant. + enrollmentKey: + # The URL where the ZenML server API is reachable. If not specified, the # clients will use the same URL used to connect them to the ZenML server. + # + # IMPORTANT: this value must be set for ZenML Pro servers. serverURL: # The URL where the ZenML dashboard is reachable. @@ -39,6 +64,8 @@ zenml: # This is value is used to compute the dashboard URLs during the web login # authentication workflow, to print dashboard URLs in log messages when # running a pipeline and for other similar tasks. + # + # This value is overridden if the `zenml.pro.enabled` value is set. dashboardURL: debug: true @@ -48,6 +75,8 @@ zenml: # ZenML server deployment type. This field is used for telemetry purposes. # Example values are "local", "kubernetes", "aws", "gcp", "azure". + # + # This value is overridden if the `zenml.pro.enabled` value is set. deploymentType: # Authentication settings that control how the ZenML server authenticates @@ -60,6 +89,8 @@ zenml: # HTTP_BASIC - HTTP Basic authentication # OAUTH2_PASSWORD_BEARER - OAuth2 password bearer # EXTERNAL - External authentication (e.g. via a remote authenticator) + # + # This value is overridden if the `zenml.pro.enabled` value is set. authType: OAUTH2_PASSWORD_BEARER # The secret key used to sign JWT tokens. This should be set to @@ -91,7 +122,6 @@ zenml: # ZenML Server ID. jwtTokenIssuer: - # The audience of the JWT tokens. If not specified, the audience is set to # the ZenML Server ID. jwtTokenAudience: @@ -102,6 +132,8 @@ zenml: # The expiration time of JWT tokens in minutes. If not specified, generated # JWT tokens will not be set to expire. + # + # This value is automatically set if the `zenml.pro.enabled` value is set. jwtTokenExpireMinutes: # The name of the http-only cookie used to store the JWT tokens used to @@ -117,20 +149,31 @@ zenml: # The origins allowed to make cross-origin requests to the ZenML server. If # not specified, all origins are allowed. Set this when the ZenML dashboard # is hosted on a different domain than the ZenML server. + # + # This value is overridden if the `zenml.pro.enabled` value is set. corsAllowOrigins: - "*" # The maximum number of failed authentication attempts allowed for an OAuth # 2.0 device before the device is locked. + # + # This value is ignored if the `zenml.auth.authType` is set to `EXTERNAL` or + # `NO_AUTH`. maxFailedDeviceAuthAttempts: 3 # The timeout in seconds after which a pending OAuth 2.0 device # authorization request expires. + # + # This value is ignored if the `zenml.auth.authType` is set to `EXTERNAL` or + # `NO_AUTH`. deviceAuthTimeout: 300 # The polling interval in seconds used by clients to poll the OAuth 2.0 # device authorization endpoint for the status of a pending device # authorization request. + # + # This value is ignored if the `zenml.auth.authType` is set to `EXTERNAL` or + # `NO_AUTH`. deviceAuthPollingInterval: 5 # The time in minutes that an OAuth 2.0 device is allowed to be used to @@ -139,6 +182,9 @@ zenml: # be used indefinitely. This controls the expiration time of the JWT tokens # issued to clients after they have authenticated with the ZenML server # using an OAuth 2.0 device. + # + # This value is ignored if the `zenml.auth.authType` is set to `EXTERNAL` or + # `NO_AUTH`. deviceExpirationMinutes: # The time in minutes that a trusted OAuth 2.0 device is allowed to be used @@ -147,28 +193,46 @@ zenml: # be used indefinitely. This controls the expiration time of the JWT tokens # issued to clients after they have authenticated with the ZenML server # using an OAuth 2.0 device that was previously trusted by the user. + # + # This value is ignored if the `zenml.auth.authType` is set to `EXTERNAL` or + # `NO_AUTH`. trustedDeviceExpirationMinutes: # The login URL of an external authenticator service to use with the # `EXTERNAL` authentication scheme. Only relevant if `zenml.auth.authType` # is set to `EXTERNAL`. + # + # This value is overridden if the `zenml.pro.enabled` value is set. externalLoginURL: # The user info URL of an external authenticator service to use with the # `EXTERNAL` authentication scheme. Only relevant if `zenml.auth.authType` # is set to `EXTERNAL`. + # + # This value is overridden if the `zenml.pro.enabled` value is set. externalUserInfoURL: # The UUID of the ZenML server to use with the `EXTERNAL` authentication # scheme. If not specified, the regular ZenML server ID (deployment ID) is # used. + # + # This value is overridden if the `zenml.pro.enabled` value is set. externalServerID: # Source pointing to a class implementing the RBAC interface defined by - # `zenml.zen_server.rbac_interface.RBACInterface`. If not specified, + # `zenml.zen_server.rbac.rbac_interface.RBACInterface`. If not specified, # RBAC will not be enabled for this server. + # + # This value is overridden if the `zenml.pro.enabled` value is set. rbacImplementationSource: + # Source pointing to a class implementing the feature gate interface defined + # by `zenml.zen_server.feature_gate.feature_gate_interface.FeatureGateInterface`. + # If not specified, feature gating will not be enabled for this server. + # + # This value is overridden if the `zenml.pro.enabled` value is set. + featureGateImplementationSource: + # The root URL path to use when behind a proxy. This is useful when the # `rewrite-target` annotation is used in the ingress controller, e.g.: # diff --git a/src/zenml/zen_server/feature_gate/feature_gate_interface.py b/src/zenml/zen_server/feature_gate/feature_gate_interface.py index df4a5d3fc7..5a93fecc34 100644 --- a/src/zenml/zen_server/feature_gate/feature_gate_interface.py +++ b/src/zenml/zen_server/feature_gate/feature_gate_interface.py @@ -20,7 +20,7 @@ class FeatureGateInterface(ABC): - """RBAC interface definition.""" + """Feature gate interface definition.""" @abstractmethod def check_entitlement(self, resource: ResourceType) -> None: diff --git a/src/zenml/zen_server/rbac/rbac_sql_zen_store.py b/src/zenml/zen_server/rbac/rbac_sql_zen_store.py index 1d6082a9e7..b688060b25 100644 --- a/src/zenml/zen_server/rbac/rbac_sql_zen_store.py +++ b/src/zenml/zen_server/rbac/rbac_sql_zen_store.py @@ -152,8 +152,11 @@ def _get_or_create_model_version( error = e if allow_creation: - created, model_version_response = ( - super()._get_or_create_model_version(model_version_request, producer_run_id=producer_run_id) + ( + created, + model_version_response, + ) = super()._get_or_create_model_version( + model_version_request, producer_run_id=producer_run_id ) else: try: diff --git a/src/zenml/zen_server/routers/auth_endpoints.py b/src/zenml/zen_server/routers/auth_endpoints.py index 63acc3b2a9..f837c8e000 100644 --- a/src/zenml/zen_server/routers/auth_endpoints.py +++ b/src/zenml/zen_server/routers/auth_endpoints.py @@ -175,6 +175,12 @@ def __init__( self.username = username self.password = password or "" elif self.grant_type == OAuthGrantTypes.OAUTH_DEVICE_CODE: + if config.auth_scheme in [AuthScheme.NO_AUTH, AuthScheme.EXTERNAL]: + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, + detail="Device authorization is not supported.", + ) + if not device_code or not client_id: logger.info("Request with missing device code or client ID") raise HTTPException( @@ -356,6 +362,13 @@ def device_authorization( The device authorization response. """ config = server_config() + + if config.auth_scheme in [AuthScheme.NO_AUTH, AuthScheme.EXTERNAL]: + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, + detail="Device authorization is not supported.", + ) + store = zen_store() try: diff --git a/src/zenml/zen_server/template_execution/utils.py b/src/zenml/zen_server/template_execution/utils.py index 33ef74b644..d2bdb5241d 100644 --- a/src/zenml/zen_server/template_execution/utils.py +++ b/src/zenml/zen_server/template_execution/utils.py @@ -149,9 +149,10 @@ def run_template( ) def _task() -> None: - pypi_requirements, apt_packages = ( - requirements_utils.get_requirements_for_stack(stack=stack) - ) + ( + pypi_requirements, + apt_packages, + ) = requirements_utils.get_requirements_for_stack(stack=stack) if build.python_version: version_info = version.parse(build.python_version) diff --git a/src/zenml/zen_server/utils.py b/src/zenml/zen_server/utils.py index 715fe78c13..65c8151625 100644 --- a/src/zenml/zen_server/utils.py +++ b/src/zenml/zen_server/utils.py @@ -20,6 +20,7 @@ TYPE_CHECKING, Any, Callable, + Dict, List, Optional, Tuple, @@ -31,6 +32,7 @@ from pydantic import BaseModel, ValidationError +from zenml import __version__ as zenml_version from zenml.config.global_config import GlobalConfiguration from zenml.config.server_config import ServerConfiguration from zenml.constants import ( @@ -616,3 +618,20 @@ def is_same_or_subdomain(source_domain: str, target_domain: str) -> bool: return True # Subdomain of subdomain return False + + +def get_zenml_headers() -> Dict[str, str]: + """Get the ZenML specific headers to be included in requests made by the server. + + Returns: + The ZenML specific headers. + """ + config = server_config() + headers = { + "zenml-server-id": str(config.get_external_server_id()), + "zenml-server-version": zenml_version, + } + if config.server_url: + headers["zenml-server-url"] = config.server_url + + return headers diff --git a/src/zenml/zen_stores/migrations/versions/1cb6477f72d6_move_artifact_save_type.py b/src/zenml/zen_stores/migrations/versions/1cb6477f72d6_move_artifact_save_type.py index ff14523cdb..3e74d80e14 100644 --- a/src/zenml/zen_stores/migrations/versions/1cb6477f72d6_move_artifact_save_type.py +++ b/src/zenml/zen_stores/migrations/versions/1cb6477f72d6_move_artifact_save_type.py @@ -23,7 +23,8 @@ def upgrade() -> None: batch_op.add_column(sa.Column("save_type", sa.TEXT(), nullable=True)) # Step 2: Move data from step_run_output_artifact.type to artifact_version.save_type - op.execute(""" + op.execute( + """ UPDATE artifact_version SET save_type = ( SELECT max(step_run_output_artifact.type) @@ -31,17 +32,22 @@ def upgrade() -> None: WHERE step_run_output_artifact.artifact_id = artifact_version.id GROUP BY artifact_id ) - """) - op.execute(""" + """ + ) + op.execute( + """ UPDATE artifact_version SET save_type = 'step_output' WHERE artifact_version.save_type = 'default' - """) - op.execute(""" + """ + ) + op.execute( + """ UPDATE artifact_version SET save_type = 'external' WHERE save_type is NULL - """) + """ + ) # # Step 3: Set save_type to non-nullable with op.batch_alter_table("artifact_version", schema=None) as batch_op: @@ -69,7 +75,8 @@ def downgrade() -> None: ) # Move data back from artifact_version.save_type to step_run_output_artifact.type - op.execute(""" + op.execute( + """ UPDATE step_run_output_artifact SET type = ( SELECT max(artifact_version.save_type) @@ -77,12 +84,15 @@ def downgrade() -> None: WHERE step_run_output_artifact.artifact_id = artifact_version.id GROUP BY artifact_id ) - """) - op.execute(""" + """ + ) + op.execute( + """ UPDATE step_run_output_artifact SET type = 'default' WHERE step_run_output_artifact.type = 'step_output' - """) + """ + ) # Set type to non-nullable with op.batch_alter_table( diff --git a/src/zenml/zen_stores/migrations/versions/b557b2871693_update_step_run_input_types.py b/src/zenml/zen_stores/migrations/versions/b557b2871693_update_step_run_input_types.py index cf397f57d9..410a481e06 100644 --- a/src/zenml/zen_stores/migrations/versions/b557b2871693_update_step_run_input_types.py +++ b/src/zenml/zen_stores/migrations/versions/b557b2871693_update_step_run_input_types.py @@ -17,17 +17,21 @@ def upgrade() -> None: """Upgrade database schema and/or data, creating a new revision.""" - op.execute(""" + op.execute( + """ UPDATE step_run_input_artifact SET type = 'step_output' WHERE type = 'default' - """) + """ + ) def downgrade() -> None: """Downgrade database schema and/or data back to the previous revision.""" - op.execute(""" + op.execute( + """ UPDATE step_run_input_artifact SET type = 'default' WHERE type = 'step_output' - """) + """ + ) diff --git a/src/zenml/zen_stores/migrations/versions/cc269488e5a9_separate_run_metadata.py b/src/zenml/zen_stores/migrations/versions/cc269488e5a9_separate_run_metadata.py index 52a4cbd8ef..6dcfe8ee8d 100644 --- a/src/zenml/zen_stores/migrations/versions/cc269488e5a9_separate_run_metadata.py +++ b/src/zenml/zen_stores/migrations/versions/cc269488e5a9_separate_run_metadata.py @@ -41,10 +41,12 @@ def upgrade() -> None: connection = op.get_bind() run_metadata_data = connection.execute( - sa.text(""" + sa.text( + """ SELECT id, resource_id, resource_type FROM run_metadata - """) + """ + ) ).fetchall() # Prepare data with new UUIDs for bulk insert @@ -107,20 +109,24 @@ def downgrade() -> None: # Fetch data from `run_metadata_resource` run_metadata_resource_data = connection.execute( - sa.text(""" + sa.text( + """ SELECT resource_id, resource_type, run_metadata_id FROM run_metadata_resource - """) + """ + ) ).fetchall() # Update `run_metadata` with the data from `run_metadata_resource` for row in run_metadata_resource_data: connection.execute( - sa.text(""" + sa.text( + """ UPDATE run_metadata SET resource_id = :resource_id, resource_type = :resource_type WHERE id = :run_metadata_id - """), + """ + ), { "resource_id": row.resource_id, "resource_type": row.resource_type, From accbe6f3288c0d4e6f6def759aba829f48e60587 Mon Sep 17 00:00:00 2001 From: Stefan Nica Date: Fri, 20 Dec 2024 13:28:29 +0100 Subject: [PATCH 06/13] Send status updates to ZenML Pro control plane on startup --- src/zenml/config/server_config.py | 2 +- src/zenml/zen_server/cloud_utils.py | 48 +++++++++++++++++++ .../deploy/helm/templates/_environment.tpl | 2 +- .../deploy/helm/templates/server-secret.yaml | 2 +- src/zenml/zen_server/zen_server_api.py | 5 ++ 5 files changed, 56 insertions(+), 3 deletions(-) diff --git a/src/zenml/config/server_config.py b/src/zenml/config/server_config.py index 5f82c501d0..963b83e316 100644 --- a/src/zenml/config/server_config.py +++ b/src/zenml/config/server_config.py @@ -24,8 +24,8 @@ ConfigDict, Field, PositiveInt, - model_validator, field_validator, + model_validator, ) from zenml.constants import ( diff --git a/src/zenml/zen_server/cloud_utils.py b/src/zenml/zen_server/cloud_utils.py index 6a52cc4877..b604636146 100644 --- a/src/zenml/zen_server/cloud_utils.py +++ b/src/zenml/zen_server/cloud_utils.py @@ -12,6 +12,7 @@ _cloud_connection: Optional["ZenMLCloudConnection"] = None + class ZenMLCloudConnection: """Class to use for communication between server and control plane.""" @@ -103,6 +104,48 @@ def post( return response + def patch( + self, + endpoint: str, + params: Optional[Dict[str, Any]] = None, + data: Optional[Dict[str, Any]] = None, + ) -> requests.Response: + """Send a PATCH request using the active session. + + Args: + endpoint: The endpoint to send the request to. This will be appended + to the base URL. + params: Parameters to include in the request. + data: Data to include in the request. + + Raises: + RuntimeError: If the request failed. + + Returns: + The response. + """ + url = self._config.api_url + endpoint + + response = self.session.post( + url=url, params=params, json=data, timeout=7 + ) + if response.status_code == 401: + # Refresh the auth token and try again + self._clear_session() + response = self.session.patch( + url=url, params=params, json=data, timeout=7 + ) + + try: + response.raise_for_status() + except requests.HTTPError as e: + raise RuntimeError( + f"Failed while trying to contact the central zenml pro " + f"service: {e}" + ) + + return response + @property def session(self) -> requests.Session: """Authenticate to the ZenML Pro Management Plane. @@ -220,3 +263,8 @@ def cloud_connection() -> ZenMLCloudConnection: _cloud_connection = ZenMLCloudConnection() return _cloud_connection + + +def send_pro_tenant_status_update() -> None: + """Send a tenant status update to the Cloud API.""" + cloud_connection().patch("/tenants/status_updates") diff --git a/src/zenml/zen_server/deploy/helm/templates/_environment.tpl b/src/zenml/zen_server/deploy/helm/templates/_environment.tpl index f100f87cc0..b0f724bceb 100644 --- a/src/zenml/zen_server/deploy/helm/templates/_environment.tpl +++ b/src/zenml/zen_server/deploy/helm/templates/_environment.tpl @@ -143,7 +143,7 @@ Returns: {{- if .ZenML.pro.enabled }} -auth_scheme: "external" +auth_scheme: EXTERNAL deployment_type: cloud cors_allow_origins: "{{ .ZenML.pro.dashboardURL }},{{ .ZenML.pro.serverURL }}" external_login_url: "{{ .ZenML.pro.dashboardURL }}/api/auth/login" diff --git a/src/zenml/zen_server/deploy/helm/templates/server-secret.yaml b/src/zenml/zen_server/deploy/helm/templates/server-secret.yaml index fcf2484177..4d3e0f2706 100644 --- a/src/zenml/zen_server/deploy/helm/templates/server-secret.yaml +++ b/src/zenml/zen_server/deploy/helm/templates/server-secret.yaml @@ -13,7 +13,7 @@ data: {{- range $k, $v := include "zenml.storeSecretEnvVariables" . | fromYaml}} {{ $k }}: {{ $v | b64enc | quote }} {{- end }} - {{- range $k, $v := include "zenml.serverSecretConfigurationAttrs" . | fromYaml}} + {{- range $k, $v := include "zenml.serverSecretEnvVariables" . | fromYaml}} {{ $k }}: {{ $v | b64enc | quote }} {{- end }} {{- range $k, $v := include "zenml.secretsStoreSecretEnvVariables" . | fromYaml}} diff --git a/src/zenml/zen_server/zen_server_api.py b/src/zenml/zen_server/zen_server_api.py index 3c4b20638a..060e9db59c 100644 --- a/src/zenml/zen_server/zen_server_api.py +++ b/src/zenml/zen_server/zen_server_api.py @@ -54,6 +54,7 @@ ) from zenml.enums import AuthScheme, SourceContextTypes from zenml.models import ServerDeploymentType +from zenml.zen_server.cloud_utils import send_pro_tenant_status_update from zenml.zen_server.exceptions import error_detail from zenml.zen_server.routers import ( actions_endpoints, @@ -389,6 +390,10 @@ def initialize() -> None: initialize_plugins() initialize_secure_headers() initialize_memcache(cfg.memcache_max_capacity, cfg.memcache_default_expiry) + if cfg.deployment_type == ServerDeploymentType.CLOUD: + # Send a tenant status update to the Cloud API to indicate that the + # ZenML server is running or to update the version and server URL. + send_pro_tenant_status_update() DASHBOARD_REDIRECT_URL = None From 04daa7aa3bcec597769ae677ea78a1de90be61b0 Mon Sep 17 00:00:00 2001 From: GitHub Actions Date: Fri, 20 Dec 2024 12:39:06 +0000 Subject: [PATCH 07/13] Auto-update of E2E template --- examples/e2e/pipelines/training.py | 13 ++++++------- 1 file changed, 6 insertions(+), 7 deletions(-) diff --git a/examples/e2e/pipelines/training.py b/examples/e2e/pipelines/training.py index 52b6ab189b..ba9a2f7548 100644 --- a/examples/e2e/pipelines/training.py +++ b/examples/e2e/pipelines/training.py @@ -119,13 +119,12 @@ def e2e_use_case_training( target=target, ) ########## Promotion stage ########## - ( - latest_metric, - current_metric, - ) = compute_performance_metrics_on_current_data( - dataset_tst=dataset_tst, - target_env=target_env, - after=["model_evaluator"], + latest_metric, current_metric = ( + compute_performance_metrics_on_current_data( + dataset_tst=dataset_tst, + target_env=target_env, + after=["model_evaluator"], + ) ) promote_with_metric_compare( From 2e03e6166d8db90de79e4452f1f3105c3d9589ef Mon Sep 17 00:00:00 2001 From: Stefan Nica Date: Fri, 20 Dec 2024 21:10:57 +0100 Subject: [PATCH 08/13] Pushed zenml Pro config deeper into the server and reworked zenml login to allow on-prem deployments --- src/zenml/cli/login.py | 158 +++++++++++++----- src/zenml/cli/server.py | 29 +++- src/zenml/config/server_config.py | 77 ++++++++- src/zenml/constants.py | 12 +- src/zenml/login/credentials.py | 51 ++++-- src/zenml/login/credentials_store.py | 69 ++++++-- src/zenml/login/pro/client.py | 7 +- src/zenml/login/pro/constants.py | 6 - src/zenml/login/pro/utils.py | 36 ++-- src/zenml/login/server_info.py | 52 ++++++ src/zenml/login/web_login.py | 17 +- src/zenml/models/v2/misc/server_models.py | 36 ++++ src/zenml/zen_server/cloud_utils.py | 104 +++++------- .../deploy/helm/templates/_environment.tpl | 27 +-- src/zenml/zen_server/deploy/helm/values.yaml | 11 +- src/zenml/zen_server/rbac/endpoint_utils.py | 6 +- .../zen_server/routers/models_endpoints.py | 3 +- .../zen_server/routers/pipelines_endpoints.py | 4 +- .../zen_server/routers/server_endpoints.py | 1 + .../routers/workspaces_endpoints.py | 4 +- src/zenml/zen_stores/base_zen_store.py | 20 ++- src/zenml/zen_stores/rest_zen_store.py | 16 +- 22 files changed, 531 insertions(+), 215 deletions(-) create mode 100644 src/zenml/login/server_info.py diff --git a/src/zenml/cli/login.py b/src/zenml/cli/login.py index 4cddc93bd3..44a012fdbf 100644 --- a/src/zenml/cli/login.py +++ b/src/zenml/cli/login.py @@ -18,7 +18,7 @@ import re import sys import time -from typing import Any, Dict, Optional, Union +from typing import Any, Dict, Optional, Tuple, Union from uuid import UUID import click @@ -34,7 +34,8 @@ IllegalOperationError, ) from zenml.logger import get_logger -from zenml.login.pro.utils import is_zenml_pro_server_url +from zenml.login.credentials import ServerType +from zenml.login.pro.constants import ZENML_PRO_API_URL from zenml.login.web_login import web_login from zenml.zen_server.utils import ( connected_to_local_server, @@ -227,6 +228,7 @@ def connect_to_pro_server( pro_server: Optional[str] = None, api_key: Optional[str] = None, refresh: bool = False, + pro_api_url: Optional[str] = None, ) -> None: """Connect the client to a ZenML Pro server. @@ -235,6 +237,7 @@ def connect_to_pro_server( If not provided, the web login flow will be initiated. api_key: The API key to use to authenticate with the ZenML Pro server. refresh: Whether to force a new login flow with the ZenML Pro server. + pro_api_url: The URL for the ZenML Pro API. Raises: ValueError: If incorrect parameters are provided. @@ -245,6 +248,8 @@ def connect_to_pro_server( from zenml.login.pro.client import ZenMLProClient from zenml.login.pro.tenant.models import TenantStatus + pro_api_url = pro_api_url or ZENML_PRO_API_URL + server_id, server_url, server_name = None, None, None login = False if not pro_server: @@ -266,14 +271,9 @@ def connect_to_pro_server( server_name = pro_server else: server_url = pro_server - if not is_zenml_pro_server_url(server_url): - raise ValueError( - f"The URL '{server_url}' does not seem to belong to a ZenML Pro " - "server. Please check the server URL and try again." - ) credentials_store = get_credentials_store() - if not credentials_store.has_valid_pro_authentication(): + if not credentials_store.has_valid_pro_authentication(pro_api_url): # Without valid ZenML Pro credentials, we can only connect to a ZenML # Pro server with an API key and we also need to know the URL of the # server to connect to. @@ -291,7 +291,9 @@ def connect_to_pro_server( if login or refresh: try: - token = web_login() + token = web_login( + pro_api_url=pro_api_url, + ) except AuthorizationException as e: cli_utils.error(f"Authorization error: {e}") @@ -322,7 +324,7 @@ def connect_to_pro_server( # server argument passed to the command. server_id = UUID(tenant_id) - client = ZenMLProClient() + client = ZenMLProClient(pro_api_url) if server_id: server = client.tenant.get(server_id) @@ -416,6 +418,42 @@ def connect_to_pro_server( cli_utils.declare(f"Connected to ZenML Pro server: {server.name}.") +def is_pro_server( + url: str, +) -> Tuple[Optional[bool], Optional[str]]: + """Check if the server at the given URL is a ZenML Pro server. + + Args: + url: The URL of the server to check. + + Returns: + True if the server is a ZenML Pro server, False otherwise, and the + extracted pro API URL if the server is a ZenML Pro server, or None if + no information could be extracted. + """ + from zenml.login.credentials_store import get_credentials_store + from zenml.login.server_info import get_server_info + + # First, check the credentials store + credentials_store = get_credentials_store() + credentials = credentials_store.get_credentials(url) + if credentials: + if credentials.type == ServerType.PRO: + return True, credentials.pro_api_url + else: + return False, None + + # Next, make a request to the server itself + server_info = get_server_info(url) + if not server_info: + return None, None + + if server_info.is_pro_server(): + return True, server_info.pro_api_url + + return False, None + + def _fail_if_authentication_environment_variables_set() -> None: """Fail if any of the authentication environment variables are set.""" environment_variables = [ @@ -652,6 +690,13 @@ def _fail_if_authentication_environment_variables_set() -> None: "dashboard on a public domain. Primarily used for accessing the " "dashboard in Colab. Only used when running `zenml login --local`.", ) +@click.option( + "--pro-api-url", + type=str, + default=None, + help="Custom URL for the ZenML Pro API. Useful when connecting " + "to a self-hosted ZenML Pro deployment.", +) def login( server: Optional[str] = None, pro: bool = False, @@ -669,6 +714,7 @@ def login( blocking: bool = False, image: Optional[str] = None, ngrok_token: Optional[str] = None, + pro_api_url: Optional[str] = None, ) -> None: """Connect to a remote ZenML server. @@ -693,6 +739,7 @@ def login( ngrok_token: An ngrok auth token to use for exposing the local ZenML dashboard on a public domain. Primarily used for accessing the dashboard in Colab. + pro_api_url: Custom URL for the ZenML Pro API. """ _fail_if_authentication_environment_variables_set() @@ -717,6 +764,7 @@ def login( connect_to_pro_server( pro_server=server, refresh=True, + pro_api_url=pro_api_url, ) return @@ -742,29 +790,45 @@ def login( ) if server is not None: - if is_zenml_pro_server_url(server) or not re.match( - r"^(https?|mysql)://", server - ): - # The server argument is a ZenML Pro server URL, server name or UUID + if not re.match(r"^(https?|mysql)://", server): + # The server argument is a ZenML Pro server name or UUID connect_to_pro_server( pro_server=server, api_key=api_key_value, refresh=refresh, + pro_api_url=pro_api_url, ) else: - connect_to_server( - url=server, - api_key=api_key_value, - verify_ssl=verify_ssl, - refresh=refresh, - ) + # The server argument is a server URL + + # First, try to discover if the server is a ZenML Pro server or not + server_is_pro, server_pro_api_url = is_pro_server(server) + if server_is_pro: + connect_to_pro_server( + pro_server=server, + api_key=api_key_value, + refresh=refresh, + # Prefer the pro API URL extracted from the server info if + # available + pro_api_url=server_pro_api_url or pro_api_url, + ) + else: + connect_to_server( + url=server, + api_key=api_key_value, + verify_ssl=verify_ssl, + refresh=refresh, + ) elif current_non_local_server: # The server argument is not provided, so we default to # re-authenticating to the current non-local server that the client is # connected to. server = current_non_local_server - if is_zenml_pro_server_url(server): + # First, try to discover if the server is a ZenML Pro server or not + server_is_pro, server_pro_api_url = is_pro_server(server) + + if server_is_pro: cli_utils.declare( "No server argument was provided. Re-authenticating to " "ZenML Pro...\n" @@ -775,6 +839,9 @@ def login( pro_server=server, api_key=api_key_value, refresh=True, + # Prefer the pro API URL extracted from the server info if + # available + pro_api_url=server_pro_api_url or pro_api_url, ) else: cli_utils.declare( @@ -803,6 +870,7 @@ def login( ) connect_to_pro_server( api_key=api_key_value, + pro_api_url=pro_api_url, ) @@ -859,11 +927,19 @@ def login( default=False, type=click.BOOL, ) +@click.option( + "--pro-api-url", + type=str, + default=None, + help="Custom URL for the ZenML Pro API. Useful when disconnecting " + "from a self-hosted ZenML Pro deployment.", +) def logout( server: Optional[str] = None, local: bool = False, clear: bool = False, pro: bool = False, + pro_api_url: Optional[str] = None, ) -> None: """Disconnect from a ZenML server. @@ -872,6 +948,7 @@ def logout( clear: Clear all stored credentials and tokens. local: Disconnect from the local ZenML server. pro: Log out from ZenML Pro. + pro_api_url: Custom URL for the ZenML Pro API. """ from zenml.login.credentials_store import get_credentials_store @@ -887,8 +964,9 @@ def logout( "The `--pro` flag cannot be used with a specific server URL." ) - if credentials_store.has_valid_pro_authentication(): - credentials_store.clear_pro_credentials() + pro_api_url = pro_api_url or ZENML_PRO_API_URL + if credentials_store.has_valid_pro_authentication(pro_api_url): + credentials_store.clear_pro_credentials(pro_api_url) cli_utils.declare("Logged out from ZenML Pro.") else: cli_utils.declare( @@ -896,10 +974,13 @@ def logout( ) if clear: - if is_zenml_pro_server_url(store_cfg.url): + # Try to determine if the client is currently connected to a ZenML + # Pro server with the given pro API URL + credentials = credentials_store.get_credentials(store_cfg.url) + if credentials and credentials.pro_api_url == pro_api_url: gc.set_default_store() - credentials_store.clear_all_pro_tokens() + credentials_store.clear_all_pro_tokens(pro_api_url) cli_utils.declare("Logged out from all ZenML Pro servers.") return @@ -938,10 +1019,11 @@ def logout( assert server is not None - if is_zenml_pro_server_url(server): - gc.set_default_store() - credentials = credentials_store.get_credentials(server) - if credentials and (clear or store_cfg.url == server): + gc.set_default_store() + credentials = credentials_store.get_credentials(server) + + if credentials and (clear or store_cfg.url == server): + if credentials.type == ServerType.PRO: cli_utils.declare( f"Logging out from ZenML Pro server '{credentials.server_name}'." ) @@ -955,14 +1037,6 @@ def logout( "with 'zenml login '." ) else: - cli_utils.declare( - f"The client is not currently connected to the ZenML Pro server " - f"at '{server}'." - ) - else: - gc.set_default_store() - credentials = credentials_store.get_credentials(server) - if credentials and (clear or store_cfg.url == server): cli_utils.declare(f"Logging out from {server}.") if clear: credentials_store.clear_credentials(server_url=server) @@ -972,8 +1046,8 @@ def logout( "to the same server or 'zenml server list' to view other available " "servers that you can connect to with 'zenml login '." ) - else: - cli_utils.declare( - f"The client is not currently connected to the ZenML server at " - f"'{server}'." - ) + else: + cli_utils.declare( + f"The client is not currently connected to the ZenML server at " + f"'{server}'." + ) diff --git a/src/zenml/cli/server.py b/src/zenml/cli/server.py index 39c241c377..9123b6f692 100644 --- a/src/zenml/cli/server.py +++ b/src/zenml/cli/server.py @@ -187,6 +187,7 @@ def status() -> None: """Show details about the current configuration.""" from zenml.login.credentials_store import get_credentials_store from zenml.login.pro.client import ZenMLProClient + from zenml.login.pro.constants import ZENML_PRO_API_URL gc = GlobalConfiguration() client = Client() @@ -214,10 +215,11 @@ def status() -> None: if server.type == ServerType.PRO: # If connected to a ZenML Pro server, refresh the server info pro_credentials = credentials_store.get_pro_credentials( - allow_expired=False + pro_api_url=server.pro_api_url or ZENML_PRO_API_URL, + allow_expired=False, ) if pro_credentials: - pro_client = ZenMLProClient() + pro_client = ZenMLProClient(pro_credentials.url) pro_servers = pro_client.tenant.list( url=store_cfg.url, member_only=True ) @@ -551,19 +553,36 @@ def server() -> None: help="Show all ZenML servers, including those that are not running " "and those with an expired authentication.", ) -def server_list(verbose: bool = False, all: bool = False) -> None: +@click.option( + "--pro-api-url", + type=str, + default=None, + help="Custom URL for the ZenML Pro API. Useful when disconnecting " + "from a self-hosted ZenML Pro deployment.", +) +def server_list( + verbose: bool = False, + all: bool = False, + pro_api_url: Optional[str] = None, +) -> None: """List all ZenML servers that this client is authorized to access. Args: verbose: Whether to show verbose output. all: Whether to show all ZenML servers. + pro_api_url: Custom URL for the ZenML Pro API. """ from zenml.login.credentials_store import get_credentials_store from zenml.login.pro.client import ZenMLProClient from zenml.login.pro.tenant.models import TenantRead, TenantStatus + from zenml.login.pro.constants import ZENML_PRO_API_URL + + pro_api_url = pro_api_url or ZENML_PRO_API_URL credentials_store = get_credentials_store() - pro_token = credentials_store.get_pro_token(allow_expired=True) + pro_token = credentials_store.get_pro_token( + allow_expired=True, pro_api_url=pro_api_url + ) current_store_config = GlobalConfiguration().store_configuration # The list of ZenML Pro servers kept in the credentials store @@ -583,7 +602,7 @@ def server_list(verbose: bool = False, all: bool = False) -> None: accessible_pro_servers: List[TenantRead] = [] try: - client = ZenMLProClient() + client = ZenMLProClient(pro_api_url) accessible_pro_servers = client.tenant.list(member_only=not all) except AuthorizationException as e: cli_utils.warning(f"ZenML Pro authorization error: {e}") diff --git a/src/zenml/config/server_config.py b/src/zenml/config/server_config.py index 963b83e316..b5d3375014 100644 --- a/src/zenml/config/server_config.py +++ b/src/zenml/config/server_config.py @@ -29,6 +29,7 @@ ) from zenml.constants import ( + DEFAULT_REPORTABLE_RESOURCES, DEFAULT_ZENML_JWT_TOKEN_ALGORITHM, DEFAULT_ZENML_JWT_TOKEN_LEEWAY, DEFAULT_ZENML_SERVER_DEVICE_AUTH_POLLING, @@ -250,7 +251,7 @@ class ServerConfiguration(BaseModel): server_url: Optional[str] = None dashboard_url: Optional[str] = None root_url_path: str = "" - metadata: Dict[str, Any] = {} + metadata: Dict[str, str] = {} auth_scheme: AuthScheme = AuthScheme.OAUTH2_PASSWORD_BEARER jwt_token_algorithm: str = DEFAULT_ZENML_JWT_TOKEN_ALGORITHM jwt_token_issuer: Optional[str] = None @@ -284,6 +285,7 @@ class ServerConfiguration(BaseModel): rbac_implementation_source: Optional[str] = None feature_gate_implementation_source: Optional[str] = None + reportable_resources: List[str] = [] workload_manager_implementation_source: Optional[str] = None pipeline_run_auth_window: int = ( DEFAULT_ZENML_SERVER_PIPELINE_RUN_AUTH_WINDOW @@ -555,7 +557,66 @@ def get_server_config(cls) -> "ServerConfiguration": k[len(ENV_ZENML_SERVER_PREFIX) :].lower() ] = v - return ServerConfiguration(**env_server_config) + server_config = ServerConfiguration(**env_server_config) + + if server_config.deployment_type == ServerDeploymentType.CLOUD: + # If the zenml server is a Pro server, we will apply the Pro + # configuration overrides to the server config automatically. + # TODO: these should be retrieved dynamically from the ZenML Pro + # API. + server_pro_config = ServerProConfiguration.get_server_config() + server_config.auth_scheme = AuthScheme.EXTERNAL + server_config.external_login_url = ( + f"{server_pro_config.dashboard_url}/api/auth/login" + ) + server_config.external_user_info_url = ( + f"{server_pro_config.api_url}/users/authorize_server" + ) + server_config.external_server_id = server_pro_config.tenant_id + server_config.rbac_implementation_source = ( + "zenml.zen_server.rbac.zenml_cloud_rbac.ZenMLCloudRBAC" + ) + server_config.feature_gate_implementation_source = "zenml.zen_server.feature_gate.zenml_cloud_feature_gate.ZenMLCloudFeatureGateInterface" + server_config.reportable_resources = DEFAULT_REPORTABLE_RESOURCES + server_config.dashboard_url = f"{server_pro_config.dashboard_url}/organizations/{server_pro_config.organization_id}/tenants/{server_pro_config.tenant_id}" + server_config.metadata.update( + dict( + account_id=str(server_pro_config.organization_id), + organization_id=str(server_pro_config.organization_id), + tenant_id=str(server_pro_config.tenant_id), + ) + ) + if server_pro_config.tenant_name: + server_config.metadata.update( + dict(tenant_name=server_pro_config.tenant_name) + ) + + extra_cors_allow_origins = [ + server_pro_config.dashboard_url, + server_pro_config.api_url, + ] + if server_config.server_url: + extra_cors_allow_origins.append(server_config.server_url) + if ( + not server_config.cors_allow_origins + or server_config.cors_allow_origins == ["*"] + ): + server_config.cors_allow_origins = extra_cors_allow_origins + else: + server_config.cors_allow_origins += extra_cors_allow_origins + if "*" in server_config.cors_allow_origins: + # Remove the wildcard from the list + server_config.cors_allow_origins.remove("*") + + # Remove duplicates + server_config.cors_allow_origins = list( + set(server_config.cors_allow_origins) + ) + + if server_config.jwt_token_expire_minutes is None: + server_config.jwt_token_expire_minutes = 60 + + return server_config model_config = ConfigDict( # Allow extra attributes from configs of previous ZenML versions to @@ -573,16 +634,26 @@ class ServerProConfiguration(BaseModel): Attributes: api_url: The ZenML Pro API URL. + dashboard_url: The ZenML Pro dashboard URL. oauth2_client_secret: The ZenML Pro OAuth2 client secret used to authenticate the ZenML server with the ZenML Pro API. oauth2_audience: The OAuth2 audience. + organization_id: The ZenML Pro organization ID. + organization_name: The ZenML Pro organization name. + tenant_id: The ZenML Pro tenant ID. + tenant_name: The ZenML Pro tenant name. """ api_url: str + dashboard_url: str oauth2_client_secret: str oauth2_audience: str + organization_id: UUID + organization_name: Optional[str] = None + tenant_id: UUID + tenant_name: Optional[str] = None - @field_validator("api_url") + @field_validator("api_url", "dashboard_url") @classmethod def _strip_trailing_slashes_url(cls, url: str) -> str: """Strip any trailing slashes on the API URL. diff --git a/src/zenml/constants.py b/src/zenml/constants.py index 21b50425b8..593767f32d 100644 --- a/src/zenml/constants.py +++ b/src/zenml/constants.py @@ -179,9 +179,6 @@ def handle_int_env_var(var: str, default: int = 0) -> int: ENV_ZENML_SERVER_PRO_PREFIX = "ZENML_SERVER_PRO_" ENV_ZENML_SERVER_DEPLOYMENT_TYPE = f"{ENV_ZENML_SERVER_PREFIX}DEPLOYMENT_TYPE" ENV_ZENML_SERVER_AUTH_SCHEME = f"{ENV_ZENML_SERVER_PREFIX}AUTH_SCHEME" -ENV_ZENML_SERVER_REPORTABLE_RESOURCES = ( - f"{ENV_ZENML_SERVER_PREFIX}REPORTABLE_RESOURCES" -) ENV_ZENML_SERVER_AUTO_ACTIVATE = f"{ENV_ZENML_SERVER_PREFIX}AUTO_ACTIVATE" ENV_ZENML_RUN_SINGLE_STEPS_WITHOUT_STACK = ( "ZENML_RUN_SINGLE_STEPS_WITHOUT_STACK" @@ -322,14 +319,7 @@ def handle_int_env_var(var: str, default: int = 0) -> int: DEFAULT_ZENML_SERVER_REPORT_USER_ACTIVITY_TO_DB_SECONDS = 30 DEFAULT_ZENML_SERVER_MAX_REQUEST_BODY_SIZE_IN_BYTES = 256 * 1024 * 1024 -# Configurations to decide which resources report their usage and check for -# entitlement in the case of a cloud deployment. Expected Format is this: -# ENV_ZENML_REPORTABLE_RESOURCES='["Foo", "bar"]' -REPORTABLE_RESOURCES: List[str] = handle_json_env_var( - ENV_ZENML_SERVER_REPORTABLE_RESOURCES, - expected_type=list, - default=["pipeline", "pipeline_run", "model"], -) +DEFAULT_REPORTABLE_RESOURCES = ["pipeline", "pipeline_run", "model"] REQUIRES_CUSTOM_RESOURCE_REPORTING = ["pipeline", "pipeline_run"] # API Endpoint paths: diff --git a/src/zenml/login/credentials.py b/src/zenml/login/credentials.py index 529c883743..8aa218bf89 100644 --- a/src/zenml/login/credentials.py +++ b/src/zenml/login/credentials.py @@ -23,6 +23,7 @@ from zenml.login.pro.constants import ZENML_PRO_API_URL, ZENML_PRO_URL from zenml.login.pro.tenant.models import TenantRead, TenantStatus from zenml.models import ServerModel +from zenml.models.v2.misc.server_models import ServerDeploymentType from zenml.services.service_status import ServiceState from zenml.utils.enum_utils import StrEnum from zenml.utils.string_utils import get_human_readable_time @@ -88,13 +89,20 @@ class ServerCredentials(BaseModel): password: Optional[str] = None # Extra server attributes + deployment_type: Optional[ServerDeploymentType] = None server_id: Optional[UUID] = None server_name: Optional[str] = None - organization_name: Optional[str] = None - organization_id: Optional[UUID] = None status: Optional[str] = None version: Optional[str] = None + # Pro server attributes + organization_name: Optional[str] = None + organization_id: Optional[UUID] = None + tenant_name: Optional[str] = None + tenant_id: Optional[UUID] = None + pro_api_url: Optional[str] = None + pro_dashboard_url: Optional[str] = None + @property def id(self) -> str: """Get the server identifier. @@ -113,11 +121,13 @@ def type(self) -> ServerType: Returns: The server type. """ - from zenml.login.pro.utils import is_zenml_pro_server_url - + if self.deployment_type == ServerDeploymentType.CLOUD: + return ServerType.PRO if self.url == ZENML_PRO_API_URL: return ServerType.PRO_API - if self.organization_id or is_zenml_pro_server_url(self.url): + if self.url == self.pro_api_url: + return ServerType.PRO + if self.organization_id or self.tenant_id: return ServerType.PRO if urlparse(self.url).hostname in [ "localhost", @@ -138,25 +148,39 @@ def update_server_info( if isinstance(server_info, ServerModel): # The server ID doesn't change during the lifetime of the server self.server_id = self.server_id or server_info.id - # All other attributes can change during the lifetime of the server + self.deployment_type = server_info.deployment_type server_name = ( - server_info.metadata.get("tenant_name") or server_info.name + server_info.pro_tenant_name + or server_info.metadata.get("tenant_name") + or server_info.name ) if server_name: self.server_name = server_name - organization_id = server_info.metadata.get("organization_id") - if organization_id: - self.organization_id = UUID(organization_id) + if server_info.pro_organization_id: + self.organization_id = server_info.pro_organization_id + if server_info.pro_tenant_id: + self.server_id = server_info.pro_tenant_id + if server_info.pro_organization_name: + self.organization_name = server_info.pro_organization_name + if server_info.pro_tenant_name: + self.tenant_name = server_info.pro_tenant_name + if server_info.pro_api_url: + self.pro_api_url = server_info.pro_api_url + if server_info.pro_dashboard_url: + self.pro_dashboard_url = server_info.pro_dashboard_url self.version = server_info.version or self.version # The server information was retrieved from the server itself, so we # can assume that the server is available self.status = "available" else: + self.deployment_type = ServerDeploymentType.CLOUD self.server_id = server_info.id self.server_name = server_info.name self.organization_name = server_info.organization_name self.organization_id = server_info.organization_id + self.tenant_name = server_info.name + self.tenant_id = server_info.id self.status = server_info.status self.version = server_info.version @@ -247,9 +271,10 @@ def dashboard_url(self) -> str: """ if self.organization_id and self.server_id: return ( - ZENML_PRO_URL + (self.pro_dashboard_url or ZENML_PRO_URL) + f"/organizations/{str(self.organization_id)}/tenants/{str(self.server_id)}" ) + return self.url @property @@ -261,8 +286,8 @@ def dashboard_organization_url(self) -> str: """ if self.organization_id: return ( - ZENML_PRO_URL + f"/organizations/{str(self.organization_id)}" - ) + self.pro_dashboard_url or ZENML_PRO_URL + ) + f"/organizations/{str(self.organization_id)}" return "" @property diff --git a/src/zenml/login/credentials_store.py b/src/zenml/login/credentials_store.py index 2287dae498..b1f0bbb146 100644 --- a/src/zenml/login/credentials_store.py +++ b/src/zenml/login/credentials_store.py @@ -289,10 +289,13 @@ def get_credentials(self, server_url: str) -> Optional[ServerCredentials]: self.check_and_reload_from_file() return self.credentials.get(server_url) - def get_pro_token(self, allow_expired: bool = False) -> Optional[APIToken]: - """Retrieve a valid token from the credentials store for the ZenML Pro API server. + def get_pro_token( + self, pro_api_url: str, allow_expired: bool = False + ) -> Optional[APIToken]: + """Retrieve a valid token from the credentials store for a ZenML Pro API server. Args: + pro_api_url: The URL of the ZenML Pro API server. allow_expired: Whether to allow expired tokens to be returned. The default behavior is to return None if a token does exist but is expired. @@ -300,14 +303,18 @@ def get_pro_token(self, allow_expired: bool = False) -> Optional[APIToken]: Returns: The stored token if it exists and is not expired, None otherwise. """ - return self.get_token(ZENML_PRO_API_URL, allow_expired) + credential = self.get_pro_credentials(pro_api_url, allow_expired) + if credential: + return credential.api_token + return None def get_pro_credentials( - self, allow_expired: bool = False + self, pro_api_url: str, allow_expired: bool = False ) -> Optional[ServerCredentials]: - """Retrieve a valid token from the credentials store for the ZenML Pro API server. + """Retrieve a valid token from the credentials store for a ZenML Pro API server. Args: + pro_api_url: The URL of the ZenML Pro API server. allow_expired: Whether to allow expired tokens to be returned. The default behavior is to return None if a token does exist but is expired. @@ -315,26 +322,47 @@ def get_pro_credentials( Returns: The stored credentials if they exist and are not expired, None otherwise. """ - credential = self.get_credentials(ZENML_PRO_API_URL) + credential = self.get_credentials(pro_api_url) if ( credential + and credential.type == ServerType.PRO_API and credential.api_token and (not credential.api_token.expired or allow_expired) ): return credential return None - def clear_pro_credentials(self) -> None: - """Delete the token from the store for the ZenML Pro API server.""" - self.clear_token(ZENML_PRO_API_URL) + def clear_pro_credentials(self, pro_api_url: str) -> None: + """Delete the token from the store for a ZenML Pro API server. + + Args: + pro_api_url: The URL of the ZenML Pro API server. + """ + self.clear_token(pro_api_url) + + def clear_all_pro_tokens( + self, pro_api_url: str + ) -> List[ServerCredentials]: + """Delete all tokens from the store for ZenML Pro servers connected to a given API server. + + Args: + pro_api_url: The URL of the ZenML Pro API server. - def clear_all_pro_tokens(self) -> None: - """Delete all tokens from the store for ZenML Pro API servers.""" + Returns: + A list of the credentials that were cleared. + """ + credentials_to_clear = [] for server_url, server in self.credentials.copy().items(): - if server.type == ServerType.PRO: + if ( + server.type == ServerType.PRO + and server.pro_api_url + and server.pro_api_url == pro_api_url + ): if server.api_key: continue self.clear_token(server_url) + credentials_to_clear.append(server_url) + return credentials_to_clear def has_valid_authentication(self, url: str) -> bool: """Check if a valid authentication credential for the given server URL is stored. @@ -357,13 +385,16 @@ def has_valid_authentication(self, url: str) -> bool: token = credential.api_token return token is not None and not token.expired - def has_valid_pro_authentication(self) -> bool: - """Check if a valid token for the ZenML Pro API server is stored. + def has_valid_pro_authentication(self, pro_api_url: str) -> bool: + """Check if a valid token for a ZenML Pro API server is stored. + + Args: + pro_api_url: The URL of the ZenML Pro API server. Returns: bool: True if a valid token is stored, False otherwise. """ - return self.get_token(ZENML_PRO_API_URL) is not None + return self.get_pro_token(pro_api_url) is not None def set_api_key( self, @@ -433,12 +464,14 @@ def set_token( self, server_url: str, token_response: OAuthTokenResponse, + is_zenml_pro: bool = False, ) -> APIToken: """Store an API token received from an OAuth2 server. Args: server_url: The server URL for which the token is to be stored. token_response: Token response received from an OAuth2 server. + is_zenml_pro: Whether the token is for a ZenML Pro server. Returns: APIToken: The stored token. @@ -476,10 +509,14 @@ def set_token( if credential: credential.api_token = api_token else: - self.credentials[server_url] = ServerCredentials( + credential = self.credentials[server_url] = ServerCredentials( url=server_url, api_token=api_token ) + if is_zenml_pro: + # This is how we encode that the token is for a ZenML Pro server + credential.pro_api_url = server_url + self._save_credentials() return api_token diff --git a/src/zenml/login/pro/client.py b/src/zenml/login/pro/client.py index 516f84833c..de72f50f2c 100644 --- a/src/zenml/login/pro/client.py +++ b/src/zenml/login/pro/client.py @@ -61,13 +61,12 @@ class ZenMLProClient(metaclass=SingletonMetaClass): _organization: Optional["OrganizationClient"] = None def __init__( - self, url: Optional[str] = None, api_token: Optional[APIToken] = None + self, url: str, api_token: Optional[APIToken] = None ) -> None: """Initialize the ZenML Pro client. Args: - url: The URL of the ZenML Pro API server. If not provided, the - default ZenML Pro API server URL is used. + url: The URL of the ZenML Pro API server. api_token: The API token to use for authentication. If not provided, the token is fetched from the credentials store. @@ -75,7 +74,7 @@ def __init__( AuthorizationException: If no API token is provided and no token is found in the credentials store. """ - self._url = url or ZENML_PRO_API_URL + self._url = url if api_token is None: logger.debug( "No ZenML Pro API token provided. Fetching from credentials " diff --git a/src/zenml/login/pro/constants.py b/src/zenml/login/pro/constants.py index 4441367d91..e9ea60abf1 100644 --- a/src/zenml/login/pro/constants.py +++ b/src/zenml/login/pro/constants.py @@ -26,9 +26,3 @@ DEFAULT_ZENML_PRO_URL = "https://cloud.zenml.io" ZENML_PRO_URL = os.getenv(ENV_ZENML_PRO_URL, default=DEFAULT_ZENML_PRO_URL) - -ENV_ZENML_PRO_SERVER_SUBDOMAIN = "ZENML_PRO_SERVER_SUBDOMAIN" -DEFAULT_ZENML_PRO_SERVER_SUBDOMAIN = "cloudinfra.zenml.io" -ZENML_PRO_SERVER_SUBDOMAIN = os.getenv( - ENV_ZENML_PRO_SERVER_SUBDOMAIN, default=DEFAULT_ZENML_PRO_SERVER_SUBDOMAIN -) diff --git a/src/zenml/login/pro/utils.py b/src/zenml/login/pro/utils.py index 256ced18b6..b58941c6e4 100644 --- a/src/zenml/login/pro/utils.py +++ b/src/zenml/login/pro/utils.py @@ -13,37 +13,16 @@ # permissions and limitations under the License. """ZenML Pro login utils.""" -import re - from zenml.logger import get_logger +from zenml.login.credentials import ServerType from zenml.login.credentials_store import get_credentials_store from zenml.login.pro.client import ZenMLProClient -from zenml.login.pro.constants import ZENML_PRO_SERVER_SUBDOMAIN +from zenml.login.pro.constants import ZENML_PRO_API_URL from zenml.login.pro.tenant.models import TenantStatus logger = get_logger(__name__) -def is_zenml_pro_server_url(url: str) -> bool: - """Check if a given URL is a ZenML Pro server. - - Args: - url: URL to check - - Returns: - True if the URL is a ZenML Pro tenant, False otherwise - """ - domain_regex = ZENML_PRO_SERVER_SUBDOMAIN.replace(".", r"\.") - return bool( - re.match( - r"^(https://)?[a-zA-Z0-9-\.]+\.{domain}/?$".format( - domain=domain_regex - ), - url, - ) - ) - - def get_troubleshooting_instructions(url: str) -> str: """Get troubleshooting instructions for a given ZenML Pro server URL. @@ -54,8 +33,15 @@ def get_troubleshooting_instructions(url: str) -> str: Troubleshooting instructions """ credentials_store = get_credentials_store() - if credentials_store.has_valid_pro_authentication(): - client = ZenMLProClient() + + credentials = credentials_store.get_credentials(url) + if credentials and credentials.type == ServerType.PRO: + pro_api_url = credentials.pro_api_url or ZENML_PRO_API_URL + + if pro_api_url and credentials_store.has_valid_pro_authentication( + pro_api_url + ): + client = ZenMLProClient(pro_api_url) try: servers = client.tenant.list(url=url, member_only=False) diff --git a/src/zenml/login/server_info.py b/src/zenml/login/server_info.py new file mode 100644 index 0000000000..ba4a240e0c --- /dev/null +++ b/src/zenml/login/server_info.py @@ -0,0 +1,52 @@ +# Copyright (c) ZenML GmbH 2024. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at: +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +# or implied. See the License for the specific language governing +# permissions and limitations under the License. +"""ZenML server information retrieval.""" + +from typing import Optional + +from zenml.logger import get_logger +from zenml.models import ServerModel +from zenml.zen_stores.rest_zen_store import ( + RestZenStore, + RestZenStoreConfiguration, +) + +logger = get_logger(__name__) + + +def get_server_info(url: str) -> Optional[ServerModel]: + """Retrieve server information from a remote ZenML server. + + Args: + url: The URL of the ZenML server. + + Returns: + The server information or None if the server info could not be fetched. + """ + # Here we try to leverage the existing RestZenStore support to fetch the + # server info and only the server info, which doesn't actually need + # any authentication. + try: + store = RestZenStore( + config=RestZenStoreConfiguration( + url=url, + ) + ) + return store.server_info + except Exception as e: + logger.warning( + f"Failed to fetch server info from the server running at {url}: {e}" + ) + + return None diff --git a/src/zenml/login/web_login.py b/src/zenml/login/web_login.py index 740c4d7552..66f52ee541 100644 --- a/src/zenml/login/web_login.py +++ b/src/zenml/login/web_login.py @@ -34,13 +34,14 @@ from zenml.logger import get_logger from zenml.login.credentials import APIToken from zenml.login.pro.constants import ZENML_PRO_API_URL -from zenml.login.pro.utils import is_zenml_pro_server_url logger = get_logger(__name__) def web_login( - url: Optional[str] = None, verify_ssl: Optional[Union[str, bool]] = None + url: Optional[str] = None, + verify_ssl: Optional[Union[str, bool]] = None, + pro_api_url: Optional[str] = None, ) -> APIToken: """Implements the OAuth2 Device Authorization Grant flow. @@ -61,6 +62,8 @@ def web_login( verify_ssl: Whether to verify the SSL certificate of the OAuth2 server. If a string is passed, it is interpreted as the path to a CA bundle file. + pro_api_url: The URL of the ZenML Pro API server. If not provided, the + default ZenML Pro API server URL is used. Returns: The response returned by the OAuth2 server. @@ -103,16 +106,16 @@ def web_login( if not url: # If no URL is provided, we use the ZenML Pro API server by default zenml_pro = True - url = base_url = ZENML_PRO_API_URL + url = base_url = pro_api_url or ZENML_PRO_API_URL else: # Get rid of any trailing slashes to prevent issues when having double # slashes in the URL url = url.rstrip("/") - if is_zenml_pro_server_url(url): + if pro_api_url: # This is a ZenML Pro server. The device authentication is done # through the ZenML Pro API. zenml_pro = True - base_url = ZENML_PRO_API_URL + base_url = pro_api_url else: base_url = url @@ -240,4 +243,6 @@ def web_login( ) # Save the token in the credentials store - return credentials_store.set_token(url, token_response) + return credentials_store.set_token( + url, token_response, is_zenml_pro=zenml_pro + ) diff --git a/src/zenml/models/v2/misc/server_models.py b/src/zenml/models/v2/misc/server_models.py index d7e619020a..6b029d4d4e 100644 --- a/src/zenml/models/v2/misc/server_models.py +++ b/src/zenml/models/v2/misc/server_models.py @@ -107,6 +107,42 @@ class ServerModel(BaseModel): title="Timestamp of latest user activity traced on the server.", ) + pro_dashboard_url: Optional[str] = Field( + None, + title="The base URL of the ZenML Pro dashboard to which the server " + "is connected. Only set if the server is a ZenML Pro server.", + ) + + pro_api_url: Optional[str] = Field( + None, + title="The base URL of the ZenML Pro API to which the server is " + "connected. Only set if the server is a ZenML Pro server.", + ) + + pro_organization_id: Optional[UUID] = Field( + None, + title="The ID of the ZenML Pro organization to which the server is " + "connected. Only set if the server is a ZenML Pro server.", + ) + + pro_organization_name: Optional[str] = Field( + None, + title="The name of the ZenML Pro organization to which the server is " + "connected. Only set if the server is a ZenML Pro server.", + ) + + pro_tenant_id: Optional[UUID] = Field( + None, + title="The ID of the ZenML Pro tenant to which the server is connected. " + "Only set if the server is a ZenML Pro server.", + ) + + pro_tenant_name: Optional[str] = Field( + None, + title="The name of the ZenML Pro tenant to which the server is connected. " + "Only set if the server is a ZenML Pro server.", + ) + def is_local(self) -> bool: """Return whether the server is running locally. diff --git a/src/zenml/zen_server/cloud_utils.py b/src/zenml/zen_server/cloud_utils.py index b604636146..c7b5a1db95 100644 --- a/src/zenml/zen_server/cloud_utils.py +++ b/src/zenml/zen_server/cloud_utils.py @@ -7,7 +7,6 @@ from requests.adapters import HTTPAdapter, Retry from zenml.config.server_config import ServerProConfiguration -from zenml.exceptions import SubscriptionUpgradeRequiredError from zenml.zen_server.utils import get_zenml_headers, server_config _cloud_connection: Optional["ZenMLCloudConnection"] = None @@ -23,45 +22,70 @@ def __init__(self) -> None: self._token: Optional[str] = None self._token_expires_at: Optional[datetime] = None - def get( - self, endpoint: str, params: Optional[Dict[str, Any]] + def request( + self, + method: str, + endpoint: str, + params: Optional[Dict[str, Any]] = None, + data: Optional[Dict[str, Any]] = None, ) -> requests.Response: - """Send a GET request using the active session. + """Send a request using the active session. Args: + method: The HTTP method to use. endpoint: The endpoint to send the request to. This will be appended to the base URL. params: Parameters to include in the request. + data: Data to include in the request. Raises: RuntimeError: If the request failed. - SubscriptionUpgradeRequiredError: In case the current subscription - tier is insufficient for the attempted operation. Returns: The response. """ url = self._config.api_url + endpoint - response = self.session.get(url=url, params=params, timeout=7) + response = self.session.request( + method=method, url=url, params=params, json=data, timeout=7 + ) if response.status_code == 401: - # If we get an Unauthorized error from the API serer, we refresh the - # auth token and try again + # Refresh the auth token and try again self._clear_session() - response = self.session.get(url=url, params=params, timeout=7) + response = self.session.request( + method=method, url=url, params=params, json=data, timeout=7 + ) try: response.raise_for_status() - except requests.HTTPError: - if response.status_code == 402: - raise SubscriptionUpgradeRequiredError(response.json()) - else: - raise RuntimeError( - f"Failed with the following error {response} {response.text}" - ) + except requests.HTTPError as e: + raise RuntimeError( + f"Failed while trying to contact the central zenml pro " + f"service: {e}" + ) return response + def get( + self, endpoint: str, params: Optional[Dict[str, Any]] + ) -> requests.Response: + """Send a GET request using the active session. + + Args: + endpoint: The endpoint to send the request to. This will be appended + to the base URL. + params: Parameters to include in the request. + + Raises: + RuntimeError: If the request failed. + SubscriptionUpgradeRequiredError: In case the current subscription + tier is insufficient for the attempted operation. + + Returns: + The response. + """ + return self.request(method="GET", endpoint=endpoint, params=params) + def post( self, endpoint: str, @@ -82,27 +106,9 @@ def post( Returns: The response. """ - url = self._config.api_url + endpoint - - response = self.session.post( - url=url, params=params, json=data, timeout=7 + return self.request( + method="POST", endpoint=endpoint, params=params, data=data ) - if response.status_code == 401: - # Refresh the auth token and try again - self._clear_session() - response = self.session.post( - url=url, params=params, json=data, timeout=7 - ) - - try: - response.raise_for_status() - except requests.HTTPError as e: - raise RuntimeError( - f"Failed while trying to contact the central zenml pro " - f"service: {e}" - ) - - return response def patch( self, @@ -124,27 +130,9 @@ def patch( Returns: The response. """ - url = self._config.api_url + endpoint - - response = self.session.post( - url=url, params=params, json=data, timeout=7 + return self.request( + method="PATCH", endpoint=endpoint, params=params, data=data ) - if response.status_code == 401: - # Refresh the auth token and try again - self._clear_session() - response = self.session.patch( - url=url, params=params, json=data, timeout=7 - ) - - try: - response.raise_for_status() - except requests.HTTPError as e: - raise RuntimeError( - f"Failed while trying to contact the central zenml pro " - f"service: {e}" - ) - - return response @property def session(self) -> requests.Session: @@ -267,4 +255,4 @@ def cloud_connection() -> ZenMLCloudConnection: def send_pro_tenant_status_update() -> None: """Send a tenant status update to the Cloud API.""" - cloud_connection().patch("/tenants/status_updates") + cloud_connection().patch("/tenant_status") diff --git a/src/zenml/zen_server/deploy/helm/templates/_environment.tpl b/src/zenml/zen_server/deploy/helm/templates/_environment.tpl index b0f724bceb..211f2f3b16 100644 --- a/src/zenml/zen_server/deploy/helm/templates/_environment.tpl +++ b/src/zenml/zen_server/deploy/helm/templates/_environment.tpl @@ -142,21 +142,24 @@ Returns: {{- define "zenml.serverConfigurationAttrs" -}} {{- if .ZenML.pro.enabled }} - -auth_scheme: EXTERNAL deployment_type: cloud -cors_allow_origins: "{{ .ZenML.pro.dashboardURL }},{{ .ZenML.pro.serverURL }}" -external_login_url: "{{ .ZenML.pro.dashboardURL }}/api/auth/login" -external_user_info_url: "{{ .ZenML.pro.dashboardURL }}/users/authorize_server" -external_server_id: {{ .ZenML.pro.tenantID | quote }} -jwt_token_expire_minutes: "60" -rbac_implementation_source: "zenml.zen_server.rbac.zenml_cloud_rbac.ZenMLCloudRBAC" -feature_gate_implementation_source: "zenml.zen_server.feature_gate.zenml_cloud_feature_gate.ZenMLCloudFeatureGateInterface" -dashboard_url: "{{ .ZenML.pro.dashboardURL }}/organizations/{{ .ZenML.pro.organizationID }}/tenants/{{ .ZenML.pro.tenantID }}" -metadata: '{"account_id":"{{ .ZenML.pro.organizationID }}","organization_id": "{{ .ZenML.pro.organizationID }}","tenant_id":"{{ .ZenML.pro.tenantID }}"}' -reportable_resources: '["pipeline","pipeline_run","model"]' pro_api_url: "{{ .ZenML.pro.apiURL }}" +pro_dashboard_url: "{{ .ZenML.pro.dashboardURL }}" pro_oauth2_audience: "{{ .ZenML.pro.apiURL }}" +pro_organization_id: "{{ .ZenML.pro.organizationID }}" +pro_tenant_id: "{{ .ZenML.pro.tenantID }}" +{{- if .ZenML.pro.tenantName }} +pro_tenant_name: "{{ .ZenML.pro.tenantName }}" +{{- end }} +{{- if .ZenML.pro.organizationName }} +pro_organization_name: "{{ .ZenML.pro.organizationName }}" +{{- end }} +{{- if .ZenML.pro.extraCorsOrigins }} +cors_allow_origins: "{{ join "," .ZenML.pro.extraCorsOrigins }}" +{{- end }} +{{- if .ZenML.auth.jwtTokenExpireMinutes }} +jwt_token_expire_minutes: {{ .ZenML.auth.jwtTokenExpireMinutes | quote }} +{{- end }} {{- else }} diff --git a/src/zenml/zen_server/deploy/helm/values.yaml b/src/zenml/zen_server/deploy/helm/values.yaml index bea454d7ad..df4401fe2e 100644 --- a/src/zenml/zen_server/deploy/helm/values.yaml +++ b/src/zenml/zen_server/deploy/helm/values.yaml @@ -35,18 +35,27 @@ zenml: # `pro` section. enabled: false - # The URL where the ZenML Pro server API is reachable. + # The URL where the ZenML Pro server API is reachable apiURL: https://cloudapi.zenml.io # The URL where the ZenML Pro dashboard is reachable. dashboardURL: https://cloud.zenml.io + # Additional origins to allow in the CORS policy. + extraCorsOrigins: + # The ID of the ZenML Pro tenant to use. tenantID: + # The name of the ZenML Pro tenant to use. + tenantName: + # The ID of the ZenML Pro organization to use. organizationID: + # The name of the ZenML Pro organization to use. + organizationName: + # The enrollment key to use for the ZenML Pro tenant. enrollmentKey: diff --git a/src/zenml/zen_server/rbac/endpoint_utils.py b/src/zenml/zen_server/rbac/endpoint_utils.py index de2cb958ed..38b8b00499 100644 --- a/src/zenml/zen_server/rbac/endpoint_utils.py +++ b/src/zenml/zen_server/rbac/endpoint_utils.py @@ -19,7 +19,6 @@ from pydantic import BaseModel from zenml.constants import ( - REPORTABLE_RESOURCES, REQUIRES_CUSTOM_RESOURCE_REPORTING, ) from zenml.exceptions import IllegalOperationError @@ -43,6 +42,7 @@ verify_permission, verify_permission_for_model, ) +from zenml.zen_server.utils import server_config AnyRequest = TypeVar("AnyRequest", bound=BaseRequest) AnyResponse = TypeVar("AnyResponse", bound=BaseIdentifiedResponse) # type: ignore[type-arg] @@ -82,7 +82,7 @@ def verify_permissions_and_create_entity( verify_permission(resource_type=resource_type, action=Action.CREATE) needs_usage_increment = ( - resource_type in REPORTABLE_RESOURCES + resource_type in server_config().reportable_resources and resource_type not in REQUIRES_CUSTOM_RESOURCE_REPORTING ) if needs_usage_increment: @@ -129,7 +129,7 @@ def verify_permissions_and_batch_create_entity( verify_permission(resource_type=resource_type, action=Action.CREATE) - if resource_type in REPORTABLE_RESOURCES: + if resource_type in server_config().reportable_resources: raise RuntimeError( "Batch requests are currently not possible with usage-tracked features." ) diff --git a/src/zenml/zen_server/routers/models_endpoints.py b/src/zenml/zen_server/routers/models_endpoints.py index 124660e5cf..7c2341dffe 100644 --- a/src/zenml/zen_server/routers/models_endpoints.py +++ b/src/zenml/zen_server/routers/models_endpoints.py @@ -22,7 +22,6 @@ API, MODEL_VERSIONS, MODELS, - REPORTABLE_RESOURCES, VERSION_1, ) from zenml.models import ( @@ -170,7 +169,7 @@ def delete_model( ) if server_config().feature_gate_enabled: - if ResourceType.MODEL in REPORTABLE_RESOURCES: + if ResourceType.MODEL in server_config().reportable_resources: report_decrement(ResourceType.MODEL, resource_id=model.id) diff --git a/src/zenml/zen_server/routers/pipelines_endpoints.py b/src/zenml/zen_server/routers/pipelines_endpoints.py index 837c7738de..43a0507ecb 100644 --- a/src/zenml/zen_server/routers/pipelines_endpoints.py +++ b/src/zenml/zen_server/routers/pipelines_endpoints.py @@ -20,7 +20,6 @@ from zenml.constants import ( API, PIPELINES, - REPORTABLE_RESOURCES, RUNS, VERSION_1, ) @@ -45,6 +44,7 @@ from zenml.zen_server.utils import ( handle_exceptions, make_dependable, + server_config, zen_store, ) @@ -159,7 +159,7 @@ def delete_pipeline( ) should_decrement = ( - ResourceType.PIPELINE in REPORTABLE_RESOURCES + ResourceType.PIPELINE in server_config().reportable_resources and zen_store().count_pipelines(PipelineFilter(name=pipeline.name)) == 0 ) diff --git a/src/zenml/zen_server/routers/server_endpoints.py b/src/zenml/zen_server/routers/server_endpoints.py index a9895c6cc6..5f38e23576 100644 --- a/src/zenml/zen_server/routers/server_endpoints.py +++ b/src/zenml/zen_server/routers/server_endpoints.py @@ -31,6 +31,7 @@ from zenml.exceptions import IllegalOperationError from zenml.models import ( ServerActivationRequest, + ServerDeploymentType, ServerLoadInfo, ServerModel, ServerSettingsResponse, diff --git a/src/zenml/zen_server/routers/workspaces_endpoints.py b/src/zenml/zen_server/routers/workspaces_endpoints.py index dd26289d05..32af09d2a4 100644 --- a/src/zenml/zen_server/routers/workspaces_endpoints.py +++ b/src/zenml/zen_server/routers/workspaces_endpoints.py @@ -27,7 +27,6 @@ PIPELINE_BUILDS, PIPELINE_DEPLOYMENTS, PIPELINES, - REPORTABLE_RESOURCES, RUN_METADATA, RUN_TEMPLATES, RUNS, @@ -112,6 +111,7 @@ from zenml.zen_server.utils import ( handle_exceptions, make_dependable, + server_config, zen_store, ) @@ -524,7 +524,7 @@ def create_pipeline( # We limit pipeline namespaces, not pipeline versions needs_usage_increment = ( - ResourceType.PIPELINE in REPORTABLE_RESOURCES + ResourceType.PIPELINE in server_config().reportable_resources and zen_store().count_pipelines(PipelineFilter(name=pipeline.name)) == 0 ) diff --git a/src/zenml/zen_stores/base_zen_store.py b/src/zenml/zen_stores/base_zen_store.py index 11467c4481..99f9e8baac 100644 --- a/src/zenml/zen_stores/base_zen_store.py +++ b/src/zenml/zen_stores/base_zen_store.py @@ -53,6 +53,7 @@ UserFilter, UserResponse, WorkspaceResponse, + ServerDeploymentType, ) from zenml.utils.pydantic_utils import before_validator_handler from zenml.zen_stores.zen_store_interface import ZenStoreInterface @@ -390,7 +391,7 @@ def get_store_info(self) -> ServerModel: secrets_store_type = SecretsStoreType.NONE if isinstance(self, SqlZenStore) and self.config.secrets_store: secrets_store_type = self.config.secrets_store.type - return ServerModel( + store_info = ServerModel( id=GlobalConfiguration().user_id, active=True, version=zenml.__version__, @@ -405,6 +406,23 @@ def get_store_info(self) -> ServerModel: metadata=metadata, ) + # Add ZenML Pro specific store information to the server model, if available. + if store_info.deployment_type == ServerDeploymentType.CLOUD: + from zenml.config.server_config import ServerProConfiguration + + pro_config = ServerProConfiguration.get_server_config() + + store_info.pro_api_url = pro_config.api_url + store_info.pro_dashboard_url = pro_config.dashboard_url + store_info.pro_organization_id = pro_config.organization_id + store_info.pro_tenant_id = pro_config.tenant_id + if pro_config.tenant_name: + store_info.pro_tenant_name = pro_config.tenant_name + if pro_config.organization_name: + store_info.pro_organization_name = pro_config.organization_name + + return store_info + def is_local_store(self) -> bool: """Check if the store is local or connected to a local ZenML server. diff --git a/src/zenml/zen_stores/rest_zen_store.py b/src/zenml/zen_stores/rest_zen_store.py index 74f6961e9e..c1d424e8bf 100644 --- a/src/zenml/zen_stores/rest_zen_store.py +++ b/src/zenml/zen_stores/rest_zen_store.py @@ -124,9 +124,9 @@ from zenml.logger import get_logger from zenml.login.credentials import APIToken from zenml.login.credentials_store import get_credentials_store +from zenml.login.pro.constants import ZENML_PRO_API_URL from zenml.login.pro.utils import ( get_troubleshooting_instructions, - is_zenml_pro_server_url, ) from zenml.models import ( ActionFilter, @@ -501,7 +501,7 @@ def _initialize(self) -> None: except Exception as e: zenml_pro_extra = "" - if is_zenml_pro_server_url(self.url): + if ".zenml.io" in self.url: zenml_pro_extra = ( "\nHINT: " + get_troubleshooting_instructions(self.url) ) @@ -4090,7 +4090,17 @@ def get_or_generate_api_token(self) -> str: # regular ZenML server access token. # Get the ZenML Pro API session token, if cached and valid - pro_token = credentials_store.get_pro_token(allow_expired=True) + + # We need to determine the right ZenML Pro API URL to use + pro_api_url = self.server_info.pro_api_url + if not pro_api_url and credentials and credentials.pro_api_url: + pro_api_url = credentials.pro_api_url + if not pro_api_url: + pro_api_url = ZENML_PRO_API_URL + + pro_token = credentials_store.get_pro_token( + pro_api_url, allow_expired=True + ) if not pro_token: raise CredentialsNotValid( "You need to be logged in to ZenML Pro in order to " From 200baaa6b6833ba0b101e25a3dc8b4359bc5bcdc Mon Sep 17 00:00:00 2001 From: Stefan Nica Date: Fri, 20 Dec 2024 22:33:37 +0100 Subject: [PATCH 09/13] Fix zenml Pro API credentials store support --- src/zenml/login/credentials.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/zenml/login/credentials.py b/src/zenml/login/credentials.py index 8aa218bf89..ca76eddafb 100644 --- a/src/zenml/login/credentials.py +++ b/src/zenml/login/credentials.py @@ -126,7 +126,7 @@ def type(self) -> ServerType: if self.url == ZENML_PRO_API_URL: return ServerType.PRO_API if self.url == self.pro_api_url: - return ServerType.PRO + return ServerType.PRO_API if self.organization_id or self.tenant_id: return ServerType.PRO if urlparse(self.url).hostname in [ From 175ce976395b20d6c160b19a4bd1c85a4fd4133f Mon Sep 17 00:00:00 2001 From: Stefan Nica Date: Fri, 27 Dec 2024 15:34:47 +0100 Subject: [PATCH 10/13] Fix linter errors --- src/zenml/cli/server.py | 2 +- src/zenml/login/credentials_store.py | 3 +-- src/zenml/login/pro/client.py | 5 +--- src/zenml/login/pro/tenant/models.py | 6 +++-- src/zenml/zen_server/cloud_utils.py | 25 ++++++++----------- .../zen_server/routers/auth_endpoints.py | 3 +++ .../zen_server/routers/server_endpoints.py | 1 - src/zenml/zen_stores/base_zen_store.py | 2 +- 8 files changed, 21 insertions(+), 26 deletions(-) diff --git a/src/zenml/cli/server.py b/src/zenml/cli/server.py index 9123b6f692..c39167e44c 100644 --- a/src/zenml/cli/server.py +++ b/src/zenml/cli/server.py @@ -574,8 +574,8 @@ def server_list( """ from zenml.login.credentials_store import get_credentials_store from zenml.login.pro.client import ZenMLProClient - from zenml.login.pro.tenant.models import TenantRead, TenantStatus from zenml.login.pro.constants import ZENML_PRO_API_URL + from zenml.login.pro.tenant.models import TenantRead, TenantStatus pro_api_url = pro_api_url or ZENML_PRO_API_URL diff --git a/src/zenml/login/credentials_store.py b/src/zenml/login/credentials_store.py index b1f0bbb146..efb9c0f895 100644 --- a/src/zenml/login/credentials_store.py +++ b/src/zenml/login/credentials_store.py @@ -25,7 +25,6 @@ from zenml.io import fileio from zenml.logger import get_logger from zenml.login.credentials import APIToken, ServerCredentials, ServerType -from zenml.login.pro.constants import ZENML_PRO_API_URL from zenml.login.pro.tenant.models import TenantRead from zenml.models import OAuthTokenResponse, ServerModel from zenml.utils import yaml_utils @@ -361,7 +360,7 @@ def clear_all_pro_tokens( if server.api_key: continue self.clear_token(server_url) - credentials_to_clear.append(server_url) + credentials_to_clear.append(server) return credentials_to_clear def has_valid_authentication(self, url: str) -> bool: diff --git a/src/zenml/login/pro/client.py b/src/zenml/login/pro/client.py index de72f50f2c..cbc6925d5b 100644 --- a/src/zenml/login/pro/client.py +++ b/src/zenml/login/pro/client.py @@ -33,7 +33,6 @@ from zenml.logger import get_logger from zenml.login.credentials import APIToken from zenml.login.credentials_store import get_credentials_store -from zenml.login.pro.constants import ZENML_PRO_API_URL from zenml.login.pro.models import BaseRestAPIModel from zenml.utils.singleton import SingletonMetaClass from zenml.zen_server.exceptions import exception_from_response @@ -60,9 +59,7 @@ class ZenMLProClient(metaclass=SingletonMetaClass): _tenant: Optional["TenantClient"] = None _organization: Optional["OrganizationClient"] = None - def __init__( - self, url: str, api_token: Optional[APIToken] = None - ) -> None: + def __init__(self, url: str, api_token: Optional[APIToken] = None) -> None: """Initialize the ZenML Pro client. Args: diff --git a/src/zenml/login/pro/tenant/models.py b/src/zenml/login/pro/tenant/models.py index f5efc57a51..166d7f6af1 100644 --- a/src/zenml/login/pro/tenant/models.py +++ b/src/zenml/login/pro/tenant/models.py @@ -76,7 +76,7 @@ class ZenMLServiceStatus(BaseRestAPIModel): class ZenMLServiceRead(BaseRestAPIModel): """Pydantic Model for viewing a ZenML service.""" - configuration: ZenMLServiceConfiguration = Field( + configuration: Optional[ZenMLServiceConfiguration] = Field( description="The service configuration." ) @@ -133,7 +133,9 @@ def version(self) -> Optional[str]: Returns: The ZenML service version. """ - version = self.zenml_service.configuration.version + version = None + if self.zenml_service.configuration: + version = self.zenml_service.configuration.version if self.zenml_service.status and self.zenml_service.status.version: version = self.zenml_service.status.version diff --git a/src/zenml/zen_server/cloud_utils.py b/src/zenml/zen_server/cloud_utils.py index c7b5a1db95..0aff913563 100644 --- a/src/zenml/zen_server/cloud_utils.py +++ b/src/zenml/zen_server/cloud_utils.py @@ -7,6 +7,7 @@ from requests.adapters import HTTPAdapter, Retry from zenml.config.server_config import ServerProConfiguration +from zenml.exceptions import SubscriptionUpgradeRequiredError from zenml.zen_server.utils import get_zenml_headers, server_config _cloud_connection: Optional["ZenMLCloudConnection"] = None @@ -39,6 +40,8 @@ def request( data: Data to include in the request. Raises: + SubscriptionUpgradeRequiredError: If the current subscription tier + is insufficient for the attempted operation. RuntimeError: If the request failed. Returns: @@ -59,10 +62,13 @@ def request( try: response.raise_for_status() except requests.HTTPError as e: - raise RuntimeError( - f"Failed while trying to contact the central zenml pro " - f"service: {e}" - ) + if response.status_code == 402: + raise SubscriptionUpgradeRequiredError(response.json()) + else: + raise RuntimeError( + f"Failed while trying to contact the central zenml pro " + f"service: {e}" + ) return response @@ -76,11 +82,6 @@ def get( to the base URL. params: Parameters to include in the request. - Raises: - RuntimeError: If the request failed. - SubscriptionUpgradeRequiredError: In case the current subscription - tier is insufficient for the attempted operation. - Returns: The response. """ @@ -100,9 +101,6 @@ def post( params: Parameters to include in the request. data: Data to include in the request. - Raises: - RuntimeError: If the request failed. - Returns: The response. """ @@ -124,9 +122,6 @@ def patch( params: Parameters to include in the request. data: Data to include in the request. - Raises: - RuntimeError: If the request failed. - Returns: The response. """ diff --git a/src/zenml/zen_server/routers/auth_endpoints.py b/src/zenml/zen_server/routers/auth_endpoints.py index f837c8e000..6626cfec5c 100644 --- a/src/zenml/zen_server/routers/auth_endpoints.py +++ b/src/zenml/zen_server/routers/auth_endpoints.py @@ -360,6 +360,9 @@ def device_authorization( Returns: The device authorization response. + + Raises: + HTTPException: If the device authorization is not supported. """ config = server_config() diff --git a/src/zenml/zen_server/routers/server_endpoints.py b/src/zenml/zen_server/routers/server_endpoints.py index 5f38e23576..a9895c6cc6 100644 --- a/src/zenml/zen_server/routers/server_endpoints.py +++ b/src/zenml/zen_server/routers/server_endpoints.py @@ -31,7 +31,6 @@ from zenml.exceptions import IllegalOperationError from zenml.models import ( ServerActivationRequest, - ServerDeploymentType, ServerLoadInfo, ServerModel, ServerSettingsResponse, diff --git a/src/zenml/zen_stores/base_zen_store.py b/src/zenml/zen_stores/base_zen_store.py index 99f9e8baac..42c812800f 100644 --- a/src/zenml/zen_stores/base_zen_store.py +++ b/src/zenml/zen_stores/base_zen_store.py @@ -47,13 +47,13 @@ from zenml.logger import get_logger from zenml.models import ( ServerDatabaseType, + ServerDeploymentType, ServerModel, StackFilter, StackResponse, UserFilter, UserResponse, WorkspaceResponse, - ServerDeploymentType, ) from zenml.utils.pydantic_utils import before_validator_handler from zenml.zen_stores.zen_store_interface import ZenStoreInterface From 77b7984e766cb4f856c4c83b31f74b2feb2ddb6f Mon Sep 17 00:00:00 2001 From: Stefan Nica Date: Mon, 30 Dec 2024 16:08:02 +0100 Subject: [PATCH 11/13] Auto-generate an enrollment key if not configured --- .../deploy/helm/templates/NOTES.txt | 22 +++++++++ .../deploy/helm/templates/_environment.tpl | 48 ------------------- .../deploy/helm/templates/server-secret.yaml | 14 ++++-- src/zenml/zen_server/deploy/helm/values.yaml | 3 +- 4 files changed, 35 insertions(+), 52 deletions(-) diff --git a/src/zenml/zen_server/deploy/helm/templates/NOTES.txt b/src/zenml/zen_server/deploy/helm/templates/NOTES.txt index 51d34707cd..b3a9e0fc10 100644 --- a/src/zenml/zen_server/deploy/helm/templates/NOTES.txt +++ b/src/zenml/zen_server/deploy/helm/templates/NOTES.txt @@ -1,3 +1,24 @@ +{{- if .Values.zenml.pro.enabled }} + +The ZenML Pro server API is now active and ready to use at the following URL: + + {{ .Values.zenml.serverURL }} + +{{- if .Values.zenml.pro.enrollmentKey }} + +The following enrollment key has been used to enroll your server in the ZenML Pro control plane: + + {{ .Values.zenml.pro.enrollmentKey }} + +{{- else }} + +An enrollment key has been auto-generated for your server. Please use the following command to fetch the enrollment key: + + kubectl get secret {{ include "zenml.fullname" . }} -o jsonpath="{.data.ZENML_SERVER_PRO_OAUTH2_CLIENT_SECRET}" | base64 --decode + +{{- end }} + +{{- else }} {{- if .Values.zenml.ingress.enabled }} {{- if .Values.zenml.ingress.host }} @@ -28,3 +49,4 @@ You can get the ZenML server URL by running these commands: {{- end }} {{- end }} +{{- end }} diff --git a/src/zenml/zen_server/deploy/helm/templates/_environment.tpl b/src/zenml/zen_server/deploy/helm/templates/_environment.tpl index 211f2f3b16..d971667135 100644 --- a/src/zenml/zen_server/deploy/helm/templates/_environment.tpl +++ b/src/zenml/zen_server/deploy/helm/templates/_environment.tpl @@ -240,33 +240,6 @@ secure_headers_{{ $key }}: {{ $value | quote }} {{- end }} -{{/* -ZenML server configuration options (secret values). - -This template constructs a dictionary that is similar to the python values that -can be configured in the zenml.config.server_config.ServerConfiguration -class. Only secret values are included in this dictionary. - -The dictionary is then converted into deployment environment variables by other -templates and inserted where it is needed. - -The input is taken from a .ZenML dict that is passed to the template and -contains the values configured in the values.yaml file for the ZenML server. - -Args: - .ZenML: A dictionary with the ZenML configuration values configured for the - ZenML server. -Returns: - A dictionary with the secret values configured for the ZenML server. -*/}} -{{- define "zenml.serverSecretConfigurationAttrs" -}} - -{{- if .ZenML.pro.enabled }} -pro_oauth2_client_secret: {{ .ZenML.pro.enrollmentKey | quote }} -{{- end }} -{{- end }} - - {{/* Server configuration environment variables (non-secret values). @@ -288,27 +261,6 @@ ZENML_SERVER_{{ $k | upper }}: {{ $v | quote }} {{- end }} -{{/* -Server configuration environment variables (secret values). - -Passes the .Values.zenml dict as input to the `zenml.serverSecretConfigurationAttrs` -template and converts the output into a dictionary of environment variables that -need to be configured for the server. - -Args: - .Values: The values.yaml file for the ZenML deployment. -Returns: - A dictionary with the secret environment variables that are configured for - the server (i.e. keys starting with `ZENML_SERVER_`). -*/}} -{{- define "zenml.serverSecretEnvVariables" -}} -{{ $zenml := dict "ZenML" .Values.zenml }} -{{- range $k, $v := include "zenml.serverSecretConfigurationAttrs" $zenml | fromYaml }} -ZENML_SERVER_{{ $k | upper }}: {{ $v | quote }} -{{- end }} -{{- end }} - - {{/* Secrets store configuration options (non-secret values). diff --git a/src/zenml/zen_server/deploy/helm/templates/server-secret.yaml b/src/zenml/zen_server/deploy/helm/templates/server-secret.yaml index 4d3e0f2706..8161072afe 100644 --- a/src/zenml/zen_server/deploy/helm/templates/server-secret.yaml +++ b/src/zenml/zen_server/deploy/helm/templates/server-secret.yaml @@ -10,10 +10,18 @@ data: {{- else }} ZENML_SERVER_JWT_SECRET_KEY: {{ $prevServerSecret.data.ZENML_SERVER_JWT_SECRET_KEY | default (randAlphaNum 32 | b64enc | quote) }} {{- end }} - {{- range $k, $v := include "zenml.storeSecretEnvVariables" . | fromYaml}} - {{ $k }}: {{ $v | b64enc | quote }} + + {{- if .ZenML.pro.enabled }} + {{- if .ZenML.pro.enrollmentKey }} + ZENML_SERVER_PRO_OAUTH2_CLIENT_SECRET: {{ .ZenML.pro.enrollmentKey | b64enc | quote }} + {{- else if or .Release.IsInstall (not $prevServerSecret) }} + ZENML_SERVER_PRO_OAUTH2_CLIENT_SECRET: {{ randAlphaNum 64 | b64enc | quote }} + {{- else }} + ZENML_SERVER_PRO_OAUTH2_CLIENT_SECRET: {{ $prevServerSecret.data.ZENML_SERVER_PRO_OAUTH2_CLIENT_SECRET | default (randAlphaNum 64 | b64enc | quote) }} + {{- end }} {{- end }} - {{- range $k, $v := include "zenml.serverSecretEnvVariables" . | fromYaml}} + + {{- range $k, $v := include "zenml.storeSecretEnvVariables" . | fromYaml}} {{ $k }}: {{ $v | b64enc | quote }} {{- end }} {{- range $k, $v := include "zenml.secretsStoreSecretEnvVariables" . | fromYaml}} diff --git a/src/zenml/zen_server/deploy/helm/values.yaml b/src/zenml/zen_server/deploy/helm/values.yaml index df4401fe2e..dc590139cc 100644 --- a/src/zenml/zen_server/deploy/helm/values.yaml +++ b/src/zenml/zen_server/deploy/helm/values.yaml @@ -56,7 +56,8 @@ zenml: # The name of the ZenML Pro organization to use. organizationName: - # The enrollment key to use for the ZenML Pro tenant. + # The enrollment key to use for the ZenML Pro tenant. If not specified, + # an enrollment key will be auto-generated. enrollmentKey: # The URL where the ZenML server API is reachable. If not specified, the From 8977c6631bc87dedd978b91430af6c992ce75ef4 Mon Sep 17 00:00:00 2001 From: Stefan Nica Date: Mon, 30 Dec 2024 21:11:20 +0100 Subject: [PATCH 12/13] Fix helm chart and API token for stack deployments --- .../deploy/helm/templates/server-secret.yaml | 6 +++--- .../zen_server/routers/stack_deployment_endpoints.py | 10 +++++----- 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/src/zenml/zen_server/deploy/helm/templates/server-secret.yaml b/src/zenml/zen_server/deploy/helm/templates/server-secret.yaml index 8161072afe..00bb658d21 100644 --- a/src/zenml/zen_server/deploy/helm/templates/server-secret.yaml +++ b/src/zenml/zen_server/deploy/helm/templates/server-secret.yaml @@ -11,9 +11,9 @@ data: ZENML_SERVER_JWT_SECRET_KEY: {{ $prevServerSecret.data.ZENML_SERVER_JWT_SECRET_KEY | default (randAlphaNum 32 | b64enc | quote) }} {{- end }} - {{- if .ZenML.pro.enabled }} - {{- if .ZenML.pro.enrollmentKey }} - ZENML_SERVER_PRO_OAUTH2_CLIENT_SECRET: {{ .ZenML.pro.enrollmentKey | b64enc | quote }} + {{- if .Values.zenml.pro.enabled }} + {{- if .Values.zenml.pro.enrollmentKey }} + ZENML_SERVER_PRO_OAUTH2_CLIENT_SECRET: {{ .Values.zenml.pro.enrollmentKey | b64enc | quote }} {{- else if or .Release.IsInstall (not $prevServerSecret) }} ZENML_SERVER_PRO_OAUTH2_CLIENT_SECRET: {{ randAlphaNum 64 | b64enc | quote }} {{- else }} diff --git a/src/zenml/zen_server/routers/stack_deployment_endpoints.py b/src/zenml/zen_server/routers/stack_deployment_endpoints.py index a98a9de54f..0f045502d5 100644 --- a/src/zenml/zen_server/routers/stack_deployment_endpoints.py +++ b/src/zenml/zen_server/routers/stack_deployment_endpoints.py @@ -34,7 +34,7 @@ StackDeploymentInfo, ) from zenml.stack_deployments.utils import get_stack_deployment_class -from zenml.zen_server.auth import AuthContext, authorize +from zenml.zen_server.auth import AuthContext, authorize, generate_access_token from zenml.zen_server.exceptions import error_response from zenml.zen_server.rbac.models import Action, ResourceType from zenml.zen_server.rbac.utils import verify_permission @@ -114,10 +114,10 @@ def get_stack_deployment_config( assert token is not None # A new API token is generated for the stack deployment - expires = datetime.datetime.utcnow() + datetime.timedelta( - minutes=STACK_DEPLOYMENT_API_TOKEN_EXPIRATION - ) - api_token = token.encode(expires=expires) + api_token = generate_access_token( + user_id=token.user_id, + expires_in=STACK_DEPLOYMENT_API_TOKEN_EXPIRATION * 60, + ).access_token return stack_deployment_class( terraform=terraform, From 603a38d48c5e4dc4a70a3696877e418b6a4cd849 Mon Sep 17 00:00:00 2001 From: Stefan Nica Date: Sat, 11 Jan 2025 16:04:44 +0100 Subject: [PATCH 13/13] Restored formatting changes --- examples/quickstart/pipelines/training.py | 8 ++--- .../orchestrators/sagemaker_orchestrator.py | 8 ++--- .../wandb_experiment_tracker_flavor.py | 6 ++-- .../zen_server/template_execution/utils.py | 7 ++--- .../1cb6477f72d6_move_artifact_save_type.py | 30 +++++++------------ ...557b2871693_update_step_run_input_types.py | 12 +++----- .../cc269488e5a9_separate_run_metadata.py | 18 ++++------- 7 files changed, 32 insertions(+), 57 deletions(-) diff --git a/examples/quickstart/pipelines/training.py b/examples/quickstart/pipelines/training.py index 2f8e9ff915..55439cc582 100644 --- a/examples/quickstart/pipelines/training.py +++ b/examples/quickstart/pipelines/training.py @@ -47,11 +47,9 @@ def english_translation_pipeline( tokenized_dataset, tokenizer = tokenize_data( dataset=full_dataset, model_type=model_type ) - ( - tokenized_train_dataset, - tokenized_eval_dataset, - tokenized_test_dataset, - ) = split_dataset(tokenized_dataset) + tokenized_train_dataset, tokenized_eval_dataset, tokenized_test_dataset = ( + split_dataset(tokenized_dataset) + ) model = train_model( tokenized_dataset=tokenized_train_dataset, model_type=model_type, diff --git a/src/zenml/integrations/aws/orchestrators/sagemaker_orchestrator.py b/src/zenml/integrations/aws/orchestrators/sagemaker_orchestrator.py index 46f89424d9..f832647a97 100644 --- a/src/zenml/integrations/aws/orchestrators/sagemaker_orchestrator.py +++ b/src/zenml/integrations/aws/orchestrators/sagemaker_orchestrator.py @@ -632,11 +632,9 @@ def _compute_orchestrator_url( the URL to the dashboard view in SageMaker. """ try: - ( - region_name, - pipeline_name, - execution_id, - ) = dissect_pipeline_execution_arn(pipeline_execution.arn) + region_name, pipeline_name, execution_id = ( + dissect_pipeline_execution_arn(pipeline_execution.arn) + ) # Get the Sagemaker session session = pipeline_execution.sagemaker_session diff --git a/src/zenml/integrations/wandb/flavors/wandb_experiment_tracker_flavor.py b/src/zenml/integrations/wandb/flavors/wandb_experiment_tracker_flavor.py index 96f366af79..7a1a732170 100644 --- a/src/zenml/integrations/wandb/flavors/wandb_experiment_tracker_flavor.py +++ b/src/zenml/integrations/wandb/flavors/wandb_experiment_tracker_flavor.py @@ -23,7 +23,7 @@ cast, ) -from pydantic import BaseModel, field_validator +from pydantic import field_validator, BaseModel from zenml.config.base_settings import BaseSettings from zenml.experiment_trackers.base_experiment_tracker import ( @@ -69,8 +69,8 @@ def _convert_settings(cls, value: Any) -> Any: import wandb if isinstance(value, wandb.Settings): - # Depending on the wandb version, either `model_dump`, - # `make_static` or `to_dict` is available to convert the settings + # Depending on the wandb version, either `model_dump`, + # `make_static` or `to_dict` is available to convert the settings # to a dictionary if isinstance(value, BaseModel): return value.model_dump() diff --git a/src/zenml/zen_server/template_execution/utils.py b/src/zenml/zen_server/template_execution/utils.py index d2bdb5241d..33ef74b644 100644 --- a/src/zenml/zen_server/template_execution/utils.py +++ b/src/zenml/zen_server/template_execution/utils.py @@ -149,10 +149,9 @@ def run_template( ) def _task() -> None: - ( - pypi_requirements, - apt_packages, - ) = requirements_utils.get_requirements_for_stack(stack=stack) + pypi_requirements, apt_packages = ( + requirements_utils.get_requirements_for_stack(stack=stack) + ) if build.python_version: version_info = version.parse(build.python_version) diff --git a/src/zenml/zen_stores/migrations/versions/1cb6477f72d6_move_artifact_save_type.py b/src/zenml/zen_stores/migrations/versions/1cb6477f72d6_move_artifact_save_type.py index 3e74d80e14..ff14523cdb 100644 --- a/src/zenml/zen_stores/migrations/versions/1cb6477f72d6_move_artifact_save_type.py +++ b/src/zenml/zen_stores/migrations/versions/1cb6477f72d6_move_artifact_save_type.py @@ -23,8 +23,7 @@ def upgrade() -> None: batch_op.add_column(sa.Column("save_type", sa.TEXT(), nullable=True)) # Step 2: Move data from step_run_output_artifact.type to artifact_version.save_type - op.execute( - """ + op.execute(""" UPDATE artifact_version SET save_type = ( SELECT max(step_run_output_artifact.type) @@ -32,22 +31,17 @@ def upgrade() -> None: WHERE step_run_output_artifact.artifact_id = artifact_version.id GROUP BY artifact_id ) - """ - ) - op.execute( - """ + """) + op.execute(""" UPDATE artifact_version SET save_type = 'step_output' WHERE artifact_version.save_type = 'default' - """ - ) - op.execute( - """ + """) + op.execute(""" UPDATE artifact_version SET save_type = 'external' WHERE save_type is NULL - """ - ) + """) # # Step 3: Set save_type to non-nullable with op.batch_alter_table("artifact_version", schema=None) as batch_op: @@ -75,8 +69,7 @@ def downgrade() -> None: ) # Move data back from artifact_version.save_type to step_run_output_artifact.type - op.execute( - """ + op.execute(""" UPDATE step_run_output_artifact SET type = ( SELECT max(artifact_version.save_type) @@ -84,15 +77,12 @@ def downgrade() -> None: WHERE step_run_output_artifact.artifact_id = artifact_version.id GROUP BY artifact_id ) - """ - ) - op.execute( - """ + """) + op.execute(""" UPDATE step_run_output_artifact SET type = 'default' WHERE step_run_output_artifact.type = 'step_output' - """ - ) + """) # Set type to non-nullable with op.batch_alter_table( diff --git a/src/zenml/zen_stores/migrations/versions/b557b2871693_update_step_run_input_types.py b/src/zenml/zen_stores/migrations/versions/b557b2871693_update_step_run_input_types.py index 410a481e06..cf397f57d9 100644 --- a/src/zenml/zen_stores/migrations/versions/b557b2871693_update_step_run_input_types.py +++ b/src/zenml/zen_stores/migrations/versions/b557b2871693_update_step_run_input_types.py @@ -17,21 +17,17 @@ def upgrade() -> None: """Upgrade database schema and/or data, creating a new revision.""" - op.execute( - """ + op.execute(""" UPDATE step_run_input_artifact SET type = 'step_output' WHERE type = 'default' - """ - ) + """) def downgrade() -> None: """Downgrade database schema and/or data back to the previous revision.""" - op.execute( - """ + op.execute(""" UPDATE step_run_input_artifact SET type = 'default' WHERE type = 'step_output' - """ - ) + """) diff --git a/src/zenml/zen_stores/migrations/versions/cc269488e5a9_separate_run_metadata.py b/src/zenml/zen_stores/migrations/versions/cc269488e5a9_separate_run_metadata.py index 6dcfe8ee8d..52a4cbd8ef 100644 --- a/src/zenml/zen_stores/migrations/versions/cc269488e5a9_separate_run_metadata.py +++ b/src/zenml/zen_stores/migrations/versions/cc269488e5a9_separate_run_metadata.py @@ -41,12 +41,10 @@ def upgrade() -> None: connection = op.get_bind() run_metadata_data = connection.execute( - sa.text( - """ + sa.text(""" SELECT id, resource_id, resource_type FROM run_metadata - """ - ) + """) ).fetchall() # Prepare data with new UUIDs for bulk insert @@ -109,24 +107,20 @@ def downgrade() -> None: # Fetch data from `run_metadata_resource` run_metadata_resource_data = connection.execute( - sa.text( - """ + sa.text(""" SELECT resource_id, resource_type, run_metadata_id FROM run_metadata_resource - """ - ) + """) ).fetchall() # Update `run_metadata` with the data from `run_metadata_resource` for row in run_metadata_resource_data: connection.execute( - sa.text( - """ + sa.text(""" UPDATE run_metadata SET resource_id = :resource_id, resource_type = :resource_type WHERE id = :run_metadata_id - """ - ), + """), { "resource_id": row.resource_id, "resource_type": row.resource_type,