diff --git a/pyproject.toml b/pyproject.toml index a3da850ce7..ece3228bd2 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -73,6 +73,8 @@ orjson = { version = "~3.10.0", optional = true } Jinja2 = { version = "*", optional = true } ipinfo = { version = ">=4.4.3", optional = true } secure = { version = "~0.3.0", optional = true } +tldextract = { version = "~5.1.0", optional = true } +itsdangerous = { version = "~2.2.0", optional = true } # Optional dependencies for project templates copier = { version = ">=8.1.0", optional = true } @@ -189,6 +191,8 @@ server = [ "Jinja2", "ipinfo", "secure", + "tldextract", + "itsdangerous", ] templates = ["copier", "jinja2-time", "ruff", "pyyaml-include"] terraform = ["python-terraform"] diff --git a/src/zenml/cli/login.py b/src/zenml/cli/login.py index 486be5baea..44a012fdbf 100644 --- a/src/zenml/cli/login.py +++ b/src/zenml/cli/login.py @@ -18,7 +18,7 @@ import re import sys import time -from typing import Any, Dict, Optional, Union +from typing import Any, Dict, Optional, Tuple, Union from uuid import UUID import click @@ -34,7 +34,8 @@ IllegalOperationError, ) from zenml.logger import get_logger -from zenml.login.pro.utils import is_zenml_pro_server_url +from zenml.login.credentials import ServerType +from zenml.login.pro.constants import ZENML_PRO_API_URL from zenml.login.web_login import web_login from zenml.zen_server.utils import ( connected_to_local_server, @@ -145,6 +146,7 @@ def connect_to_server( api_key: Optional[str] = None, verify_ssl: Union[str, bool] = True, refresh: bool = False, + pro_server: bool = False, ) -> None: """Connect the client to a ZenML server or a SQL database. @@ -154,6 +156,7 @@ def connect_to_server( verify_ssl: Whether to verify the server's TLS certificate. If a string is passed, it is interpreted as the path to a CA bundle file. refresh: Whether to force a new login flow with the ZenML server. + pro_server: Whether the server is a ZenML Pro server. """ from zenml.login.credentials_store import get_credentials_store from zenml.zen_stores.base_zen_store import BaseZenStore @@ -170,7 +173,12 @@ def connect_to_server( f"Authenticating to ZenML server '{url}' using an API key..." ) credentials_store.set_api_key(url, api_key) - elif not is_zenml_pro_server_url(url): + elif pro_server: + # We don't have to do anything here assuming the user has already + # logged in to the ZenML Pro server using the ZenML Pro web login + # flow. + cli_utils.declare(f"Authenticating to ZenML server '{url}'...") + else: if refresh or not credentials_store.has_valid_authentication(url): cli_utils.declare( f"Authenticating to ZenML server '{url}' using the web " @@ -179,11 +187,6 @@ def connect_to_server( web_login(url=url, verify_ssl=verify_ssl) else: cli_utils.declare(f"Connecting to ZenML server '{url}'...") - else: - # We don't have to do anything here assuming the user has already - # logged in to the ZenML Pro server using the ZenML Pro web login - # flow. - cli_utils.declare(f"Authenticating to ZenML server '{url}'...") rest_store_config = RestZenStoreConfiguration( url=url, @@ -225,6 +228,7 @@ def connect_to_pro_server( pro_server: Optional[str] = None, api_key: Optional[str] = None, refresh: bool = False, + pro_api_url: Optional[str] = None, ) -> None: """Connect the client to a ZenML Pro server. @@ -233,6 +237,7 @@ def connect_to_pro_server( If not provided, the web login flow will be initiated. api_key: The API key to use to authenticate with the ZenML Pro server. refresh: Whether to force a new login flow with the ZenML Pro server. + pro_api_url: The URL for the ZenML Pro API. Raises: ValueError: If incorrect parameters are provided. @@ -243,6 +248,8 @@ def connect_to_pro_server( from zenml.login.pro.client import ZenMLProClient from zenml.login.pro.tenant.models import TenantStatus + pro_api_url = pro_api_url or ZENML_PRO_API_URL + server_id, server_url, server_name = None, None, None login = False if not pro_server: @@ -264,20 +271,15 @@ def connect_to_pro_server( server_name = pro_server else: server_url = pro_server - if not is_zenml_pro_server_url(server_url): - raise ValueError( - f"The URL '{server_url}' does not seem to belong to a ZenML Pro " - "server. Please check the server URL and try again." - ) credentials_store = get_credentials_store() - if not credentials_store.has_valid_pro_authentication(): + if not credentials_store.has_valid_pro_authentication(pro_api_url): # Without valid ZenML Pro credentials, we can only connect to a ZenML # Pro server with an API key and we also need to know the URL of the # server to connect to. if api_key: if server_url: - connect_to_server(server_url, api_key=api_key) + connect_to_server(server_url, api_key=api_key, pro_server=True) return else: raise ValueError( @@ -289,7 +291,9 @@ def connect_to_pro_server( if login or refresh: try: - token = web_login() + token = web_login( + pro_api_url=pro_api_url, + ) except AuthorizationException as e: cli_utils.error(f"Authorization error: {e}") @@ -320,7 +324,7 @@ def connect_to_pro_server( # server argument passed to the command. server_id = UUID(tenant_id) - client = ZenMLProClient() + client = ZenMLProClient(pro_api_url) if server_id: server = client.tenant.get(server_id) @@ -405,7 +409,7 @@ def connect_to_pro_server( f"Connecting to ZenML Pro server: {server.name} [{str(server.id)}] " ) - connect_to_server(server.url, api_key=api_key) + connect_to_server(server.url, api_key=api_key, pro_server=True) # Update the stored server info with more accurate data taken from the # ZenML Pro tenant object. @@ -414,6 +418,42 @@ def connect_to_pro_server( cli_utils.declare(f"Connected to ZenML Pro server: {server.name}.") +def is_pro_server( + url: str, +) -> Tuple[Optional[bool], Optional[str]]: + """Check if the server at the given URL is a ZenML Pro server. + + Args: + url: The URL of the server to check. + + Returns: + True if the server is a ZenML Pro server, False otherwise, and the + extracted pro API URL if the server is a ZenML Pro server, or None if + no information could be extracted. + """ + from zenml.login.credentials_store import get_credentials_store + from zenml.login.server_info import get_server_info + + # First, check the credentials store + credentials_store = get_credentials_store() + credentials = credentials_store.get_credentials(url) + if credentials: + if credentials.type == ServerType.PRO: + return True, credentials.pro_api_url + else: + return False, None + + # Next, make a request to the server itself + server_info = get_server_info(url) + if not server_info: + return None, None + + if server_info.is_pro_server(): + return True, server_info.pro_api_url + + return False, None + + def _fail_if_authentication_environment_variables_set() -> None: """Fail if any of the authentication environment variables are set.""" environment_variables = [ @@ -650,6 +690,13 @@ def _fail_if_authentication_environment_variables_set() -> None: "dashboard on a public domain. Primarily used for accessing the " "dashboard in Colab. Only used when running `zenml login --local`.", ) +@click.option( + "--pro-api-url", + type=str, + default=None, + help="Custom URL for the ZenML Pro API. Useful when connecting " + "to a self-hosted ZenML Pro deployment.", +) def login( server: Optional[str] = None, pro: bool = False, @@ -667,6 +714,7 @@ def login( blocking: bool = False, image: Optional[str] = None, ngrok_token: Optional[str] = None, + pro_api_url: Optional[str] = None, ) -> None: """Connect to a remote ZenML server. @@ -691,6 +739,7 @@ def login( ngrok_token: An ngrok auth token to use for exposing the local ZenML dashboard on a public domain. Primarily used for accessing the dashboard in Colab. + pro_api_url: Custom URL for the ZenML Pro API. """ _fail_if_authentication_environment_variables_set() @@ -715,6 +764,7 @@ def login( connect_to_pro_server( pro_server=server, refresh=True, + pro_api_url=pro_api_url, ) return @@ -740,29 +790,45 @@ def login( ) if server is not None: - if is_zenml_pro_server_url(server) or not re.match( - r"^(https?|mysql)://", server - ): - # The server argument is a ZenML Pro server URL, server name or UUID + if not re.match(r"^(https?|mysql)://", server): + # The server argument is a ZenML Pro server name or UUID connect_to_pro_server( pro_server=server, api_key=api_key_value, refresh=refresh, + pro_api_url=pro_api_url, ) else: - connect_to_server( - url=server, - api_key=api_key_value, - verify_ssl=verify_ssl, - refresh=refresh, - ) + # The server argument is a server URL + + # First, try to discover if the server is a ZenML Pro server or not + server_is_pro, server_pro_api_url = is_pro_server(server) + if server_is_pro: + connect_to_pro_server( + pro_server=server, + api_key=api_key_value, + refresh=refresh, + # Prefer the pro API URL extracted from the server info if + # available + pro_api_url=server_pro_api_url or pro_api_url, + ) + else: + connect_to_server( + url=server, + api_key=api_key_value, + verify_ssl=verify_ssl, + refresh=refresh, + ) elif current_non_local_server: # The server argument is not provided, so we default to # re-authenticating to the current non-local server that the client is # connected to. server = current_non_local_server - if is_zenml_pro_server_url(server): + # First, try to discover if the server is a ZenML Pro server or not + server_is_pro, server_pro_api_url = is_pro_server(server) + + if server_is_pro: cli_utils.declare( "No server argument was provided. Re-authenticating to " "ZenML Pro...\n" @@ -773,6 +839,9 @@ def login( pro_server=server, api_key=api_key_value, refresh=True, + # Prefer the pro API URL extracted from the server info if + # available + pro_api_url=server_pro_api_url or pro_api_url, ) else: cli_utils.declare( @@ -801,6 +870,7 @@ def login( ) connect_to_pro_server( api_key=api_key_value, + pro_api_url=pro_api_url, ) @@ -857,11 +927,19 @@ def login( default=False, type=click.BOOL, ) +@click.option( + "--pro-api-url", + type=str, + default=None, + help="Custom URL for the ZenML Pro API. Useful when disconnecting " + "from a self-hosted ZenML Pro deployment.", +) def logout( server: Optional[str] = None, local: bool = False, clear: bool = False, pro: bool = False, + pro_api_url: Optional[str] = None, ) -> None: """Disconnect from a ZenML server. @@ -870,6 +948,7 @@ def logout( clear: Clear all stored credentials and tokens. local: Disconnect from the local ZenML server. pro: Log out from ZenML Pro. + pro_api_url: Custom URL for the ZenML Pro API. """ from zenml.login.credentials_store import get_credentials_store @@ -885,8 +964,9 @@ def logout( "The `--pro` flag cannot be used with a specific server URL." ) - if credentials_store.has_valid_pro_authentication(): - credentials_store.clear_pro_credentials() + pro_api_url = pro_api_url or ZENML_PRO_API_URL + if credentials_store.has_valid_pro_authentication(pro_api_url): + credentials_store.clear_pro_credentials(pro_api_url) cli_utils.declare("Logged out from ZenML Pro.") else: cli_utils.declare( @@ -894,10 +974,13 @@ def logout( ) if clear: - if is_zenml_pro_server_url(store_cfg.url): + # Try to determine if the client is currently connected to a ZenML + # Pro server with the given pro API URL + credentials = credentials_store.get_credentials(store_cfg.url) + if credentials and credentials.pro_api_url == pro_api_url: gc.set_default_store() - credentials_store.clear_all_pro_tokens() + credentials_store.clear_all_pro_tokens(pro_api_url) cli_utils.declare("Logged out from all ZenML Pro servers.") return @@ -936,10 +1019,11 @@ def logout( assert server is not None - if is_zenml_pro_server_url(server): - gc.set_default_store() - credentials = credentials_store.get_credentials(server) - if credentials and (clear or store_cfg.url == server): + gc.set_default_store() + credentials = credentials_store.get_credentials(server) + + if credentials and (clear or store_cfg.url == server): + if credentials.type == ServerType.PRO: cli_utils.declare( f"Logging out from ZenML Pro server '{credentials.server_name}'." ) @@ -953,14 +1037,6 @@ def logout( "with 'zenml login '." ) else: - cli_utils.declare( - f"The client is not currently connected to the ZenML Pro server " - f"at '{server}'." - ) - else: - gc.set_default_store() - credentials = credentials_store.get_credentials(server) - if credentials and (clear or store_cfg.url == server): cli_utils.declare(f"Logging out from {server}.") if clear: credentials_store.clear_credentials(server_url=server) @@ -970,8 +1046,8 @@ def logout( "to the same server or 'zenml server list' to view other available " "servers that you can connect to with 'zenml login '." ) - else: - cli_utils.declare( - f"The client is not currently connected to the ZenML server at " - f"'{server}'." - ) + else: + cli_utils.declare( + f"The client is not currently connected to the ZenML server at " + f"'{server}'." + ) diff --git a/src/zenml/cli/server.py b/src/zenml/cli/server.py index 39c241c377..c39167e44c 100644 --- a/src/zenml/cli/server.py +++ b/src/zenml/cli/server.py @@ -187,6 +187,7 @@ def status() -> None: """Show details about the current configuration.""" from zenml.login.credentials_store import get_credentials_store from zenml.login.pro.client import ZenMLProClient + from zenml.login.pro.constants import ZENML_PRO_API_URL gc = GlobalConfiguration() client = Client() @@ -214,10 +215,11 @@ def status() -> None: if server.type == ServerType.PRO: # If connected to a ZenML Pro server, refresh the server info pro_credentials = credentials_store.get_pro_credentials( - allow_expired=False + pro_api_url=server.pro_api_url or ZENML_PRO_API_URL, + allow_expired=False, ) if pro_credentials: - pro_client = ZenMLProClient() + pro_client = ZenMLProClient(pro_credentials.url) pro_servers = pro_client.tenant.list( url=store_cfg.url, member_only=True ) @@ -551,19 +553,36 @@ def server() -> None: help="Show all ZenML servers, including those that are not running " "and those with an expired authentication.", ) -def server_list(verbose: bool = False, all: bool = False) -> None: +@click.option( + "--pro-api-url", + type=str, + default=None, + help="Custom URL for the ZenML Pro API. Useful when disconnecting " + "from a self-hosted ZenML Pro deployment.", +) +def server_list( + verbose: bool = False, + all: bool = False, + pro_api_url: Optional[str] = None, +) -> None: """List all ZenML servers that this client is authorized to access. Args: verbose: Whether to show verbose output. all: Whether to show all ZenML servers. + pro_api_url: Custom URL for the ZenML Pro API. """ from zenml.login.credentials_store import get_credentials_store from zenml.login.pro.client import ZenMLProClient + from zenml.login.pro.constants import ZENML_PRO_API_URL from zenml.login.pro.tenant.models import TenantRead, TenantStatus + pro_api_url = pro_api_url or ZENML_PRO_API_URL + credentials_store = get_credentials_store() - pro_token = credentials_store.get_pro_token(allow_expired=True) + pro_token = credentials_store.get_pro_token( + allow_expired=True, pro_api_url=pro_api_url + ) current_store_config = GlobalConfiguration().store_configuration # The list of ZenML Pro servers kept in the credentials store @@ -583,7 +602,7 @@ def server_list(verbose: bool = False, all: bool = False) -> None: accessible_pro_servers: List[TenantRead] = [] try: - client = ZenMLProClient() + client = ZenMLProClient(pro_api_url) accessible_pro_servers = client.tenant.list(member_only=not all) except AuthorizationException as e: cli_utils.warning(f"ZenML Pro authorization error: {e}") diff --git a/src/zenml/config/server_config.py b/src/zenml/config/server_config.py index e70a481d76..b5d3375014 100644 --- a/src/zenml/config/server_config.py +++ b/src/zenml/config/server_config.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express # or implied. See the License for the specific language governing # permissions and limitations under the License. -"""Functionality to support ZenML GlobalConfiguration.""" +"""Functionality to support ZenML Server Configuration.""" import json import os @@ -19,9 +19,17 @@ from typing import Any, Dict, List, Optional, Union from uuid import UUID -from pydantic import BaseModel, ConfigDict, Field, PositiveInt, model_validator +from pydantic import ( + BaseModel, + ConfigDict, + Field, + PositiveInt, + field_validator, + model_validator, +) from zenml.constants import ( + DEFAULT_REPORTABLE_RESOURCES, DEFAULT_ZENML_JWT_TOKEN_ALGORITHM, DEFAULT_ZENML_JWT_TOKEN_LEEWAY, DEFAULT_ZENML_SERVER_DEVICE_AUTH_POLLING, @@ -44,6 +52,7 @@ DEFAULT_ZENML_SERVER_SECURE_HEADERS_XXP, DEFAULT_ZENML_SERVER_THREAD_POOL_SIZE, ENV_ZENML_SERVER_PREFIX, + ENV_ZENML_SERVER_PRO_PREFIX, ) from zenml.enums import AuthScheme from zenml.logger import get_logger @@ -127,10 +136,6 @@ class ServerConfiguration(BaseModel): to use with the `EXTERNAL` authentication scheme. external_user_info_url: The user info URL of an external authenticator service to use with the `EXTERNAL` authentication scheme. - external_cookie_name: The name of the http-only cookie used to store the - bearer token used to authenticate with the external authenticator - service. Must be specified if the `EXTERNAL` authentication scheme - is used. external_server_id: The ID of the ZenML server to use with the `EXTERNAL` authentication scheme. If not specified, the regular ZenML server ID is used. @@ -246,7 +251,7 @@ class ServerConfiguration(BaseModel): server_url: Optional[str] = None dashboard_url: Optional[str] = None root_url_path: str = "" - metadata: Dict[str, Any] = {} + metadata: Dict[str, str] = {} auth_scheme: AuthScheme = AuthScheme.OAUTH2_PASSWORD_BEARER jwt_token_algorithm: str = DEFAULT_ZENML_JWT_TOKEN_ALGORITHM jwt_token_issuer: Optional[str] = None @@ -276,11 +281,11 @@ class ServerConfiguration(BaseModel): external_login_url: Optional[str] = None external_user_info_url: Optional[str] = None - external_cookie_name: Optional[str] = None external_server_id: Optional[UUID] = None rbac_implementation_source: Optional[str] = None feature_gate_implementation_source: Optional[str] = None + reportable_resources: List[str] = [] workload_manager_implementation_source: Optional[str] = None pipeline_run_auth_window: int = ( DEFAULT_ZENML_SERVER_PIPELINE_RUN_AUTH_WINDOW @@ -370,14 +375,6 @@ def _validate_config(cls, data: Dict[str, Any]) -> Dict[str, Any]: "authentication scheme." ) - # If the authentication scheme is set to `EXTERNAL`, the - # external cookie name must be specified. - if not data.get("external_cookie_name"): - raise ValueError( - "The external cookie name must be specified when " - "using the EXTERNAL authentication scheme." - ) - if cors_allow_origins := data.get("cors_allow_origins"): origins = cors_allow_origins.split(",") data["cors_allow_origins"] = origins @@ -552,12 +549,140 @@ def get_server_config(cls) -> "ServerConfiguration": for k, v in os.environ.items(): if v == "": continue + if k.startswith(ENV_ZENML_SERVER_PRO_PREFIX): + # Skip Pro configuration + continue if k.startswith(ENV_ZENML_SERVER_PREFIX): env_server_config[ k[len(ENV_ZENML_SERVER_PREFIX) :].lower() ] = v - return ServerConfiguration(**env_server_config) + server_config = ServerConfiguration(**env_server_config) + + if server_config.deployment_type == ServerDeploymentType.CLOUD: + # If the zenml server is a Pro server, we will apply the Pro + # configuration overrides to the server config automatically. + # TODO: these should be retrieved dynamically from the ZenML Pro + # API. + server_pro_config = ServerProConfiguration.get_server_config() + server_config.auth_scheme = AuthScheme.EXTERNAL + server_config.external_login_url = ( + f"{server_pro_config.dashboard_url}/api/auth/login" + ) + server_config.external_user_info_url = ( + f"{server_pro_config.api_url}/users/authorize_server" + ) + server_config.external_server_id = server_pro_config.tenant_id + server_config.rbac_implementation_source = ( + "zenml.zen_server.rbac.zenml_cloud_rbac.ZenMLCloudRBAC" + ) + server_config.feature_gate_implementation_source = "zenml.zen_server.feature_gate.zenml_cloud_feature_gate.ZenMLCloudFeatureGateInterface" + server_config.reportable_resources = DEFAULT_REPORTABLE_RESOURCES + server_config.dashboard_url = f"{server_pro_config.dashboard_url}/organizations/{server_pro_config.organization_id}/tenants/{server_pro_config.tenant_id}" + server_config.metadata.update( + dict( + account_id=str(server_pro_config.organization_id), + organization_id=str(server_pro_config.organization_id), + tenant_id=str(server_pro_config.tenant_id), + ) + ) + if server_pro_config.tenant_name: + server_config.metadata.update( + dict(tenant_name=server_pro_config.tenant_name) + ) + + extra_cors_allow_origins = [ + server_pro_config.dashboard_url, + server_pro_config.api_url, + ] + if server_config.server_url: + extra_cors_allow_origins.append(server_config.server_url) + if ( + not server_config.cors_allow_origins + or server_config.cors_allow_origins == ["*"] + ): + server_config.cors_allow_origins = extra_cors_allow_origins + else: + server_config.cors_allow_origins += extra_cors_allow_origins + if "*" in server_config.cors_allow_origins: + # Remove the wildcard from the list + server_config.cors_allow_origins.remove("*") + + # Remove duplicates + server_config.cors_allow_origins = list( + set(server_config.cors_allow_origins) + ) + + if server_config.jwt_token_expire_minutes is None: + server_config.jwt_token_expire_minutes = 60 + + return server_config + + model_config = ConfigDict( + # Allow extra attributes from configs of previous ZenML versions to + # permit downgrading + extra="allow", + ) + + +class ServerProConfiguration(BaseModel): + """ZenML Server Pro configuration attributes. + + All these attributes can be set through the environment with the + `ZENML_SERVER_PRO_`-Prefix. E.g. the value of the `ZENML_SERVER_PRO_API_URL` + environment variable will be extracted to api_url. + + Attributes: + api_url: The ZenML Pro API URL. + dashboard_url: The ZenML Pro dashboard URL. + oauth2_client_secret: The ZenML Pro OAuth2 client secret used to + authenticate the ZenML server with the ZenML Pro API. + oauth2_audience: The OAuth2 audience. + organization_id: The ZenML Pro organization ID. + organization_name: The ZenML Pro organization name. + tenant_id: The ZenML Pro tenant ID. + tenant_name: The ZenML Pro tenant name. + """ + + api_url: str + dashboard_url: str + oauth2_client_secret: str + oauth2_audience: str + organization_id: UUID + organization_name: Optional[str] = None + tenant_id: UUID + tenant_name: Optional[str] = None + + @field_validator("api_url", "dashboard_url") + @classmethod + def _strip_trailing_slashes_url(cls, url: str) -> str: + """Strip any trailing slashes on the API URL. + + Args: + url: The API URL. + + Returns: + The API URL with potential trailing slashes removed. + """ + return url.rstrip("/") + + @classmethod + def get_server_config(cls) -> "ServerProConfiguration": + """Get the server Pro configuration. + + Returns: + The server Pro configuration. + """ + env_server_config: Dict[str, Any] = {} + for k, v in os.environ.items(): + if v == "": + continue + if k.startswith(ENV_ZENML_SERVER_PRO_PREFIX): + env_server_config[ + k[len(ENV_ZENML_SERVER_PRO_PREFIX) :].lower() + ] = v + + return ServerProConfiguration(**env_server_config) model_config = ConfigDict( # Allow extra attributes from configs of previous ZenML versions to diff --git a/src/zenml/constants.py b/src/zenml/constants.py index a7a13edb61..c6b628fde5 100644 --- a/src/zenml/constants.py +++ b/src/zenml/constants.py @@ -176,11 +176,9 @@ def handle_int_env_var(var: str, default: int = 0) -> int: # ZenML Server environment variables ENV_ZENML_SERVER_PREFIX = "ZENML_SERVER_" +ENV_ZENML_SERVER_PRO_PREFIX = "ZENML_SERVER_PRO_" ENV_ZENML_SERVER_DEPLOYMENT_TYPE = f"{ENV_ZENML_SERVER_PREFIX}DEPLOYMENT_TYPE" ENV_ZENML_SERVER_AUTH_SCHEME = f"{ENV_ZENML_SERVER_PREFIX}AUTH_SCHEME" -ENV_ZENML_SERVER_REPORTABLE_RESOURCES = ( - f"{ENV_ZENML_SERVER_PREFIX}REPORTABLE_RESOURCES" -) ENV_ZENML_SERVER_AUTO_ACTIVATE = f"{ENV_ZENML_SERVER_PREFIX}AUTO_ACTIVATE" ENV_ZENML_RUN_SINGLE_STEPS_WITHOUT_STACK = ( "ZENML_RUN_SINGLE_STEPS_WITHOUT_STACK" @@ -321,14 +319,7 @@ def handle_int_env_var(var: str, default: int = 0) -> int: DEFAULT_ZENML_SERVER_REPORT_USER_ACTIVITY_TO_DB_SECONDS = 30 DEFAULT_ZENML_SERVER_MAX_REQUEST_BODY_SIZE_IN_BYTES = 256 * 1024 * 1024 -# Configurations to decide which resources report their usage and check for -# entitlement in the case of a cloud deployment. Expected Format is this: -# ENV_ZENML_REPORTABLE_RESOURCES='["Foo", "bar"]' -REPORTABLE_RESOURCES: List[str] = handle_json_env_var( - ENV_ZENML_SERVER_REPORTABLE_RESOURCES, - expected_type=list, - default=["pipeline", "pipeline_run", "model"], -) +DEFAULT_REPORTABLE_RESOURCES = ["pipeline", "pipeline_run", "model"] REQUIRES_CUSTOM_RESOURCE_REPORTING = ["pipeline", "pipeline_run"] # API Endpoint paths: diff --git a/src/zenml/login/credentials.py b/src/zenml/login/credentials.py index 9da8e0b01c..ca76eddafb 100644 --- a/src/zenml/login/credentials.py +++ b/src/zenml/login/credentials.py @@ -23,6 +23,7 @@ from zenml.login.pro.constants import ZENML_PRO_API_URL, ZENML_PRO_URL from zenml.login.pro.tenant.models import TenantRead, TenantStatus from zenml.models import ServerModel +from zenml.models.v2.misc.server_models import ServerDeploymentType from zenml.services.service_status import ServiceState from zenml.utils.enum_utils import StrEnum from zenml.utils.string_utils import get_human_readable_time @@ -44,7 +45,6 @@ class APIToken(BaseModel): expires_in: Optional[int] = None expires_at: Optional[datetime] = None leeway: Optional[int] = None - cookie_name: Optional[str] = None device_id: Optional[UUID] = None device_metadata: Optional[Dict[str, Any]] = None @@ -89,13 +89,20 @@ class ServerCredentials(BaseModel): password: Optional[str] = None # Extra server attributes + deployment_type: Optional[ServerDeploymentType] = None server_id: Optional[UUID] = None server_name: Optional[str] = None - organization_name: Optional[str] = None - organization_id: Optional[UUID] = None status: Optional[str] = None version: Optional[str] = None + # Pro server attributes + organization_name: Optional[str] = None + organization_id: Optional[UUID] = None + tenant_name: Optional[str] = None + tenant_id: Optional[UUID] = None + pro_api_url: Optional[str] = None + pro_dashboard_url: Optional[str] = None + @property def id(self) -> str: """Get the server identifier. @@ -114,11 +121,13 @@ def type(self) -> ServerType: Returns: The server type. """ - from zenml.login.pro.utils import is_zenml_pro_server_url - + if self.deployment_type == ServerDeploymentType.CLOUD: + return ServerType.PRO if self.url == ZENML_PRO_API_URL: return ServerType.PRO_API - if self.organization_id or is_zenml_pro_server_url(self.url): + if self.url == self.pro_api_url: + return ServerType.PRO_API + if self.organization_id or self.tenant_id: return ServerType.PRO if urlparse(self.url).hostname in [ "localhost", @@ -139,25 +148,39 @@ def update_server_info( if isinstance(server_info, ServerModel): # The server ID doesn't change during the lifetime of the server self.server_id = self.server_id or server_info.id - # All other attributes can change during the lifetime of the server + self.deployment_type = server_info.deployment_type server_name = ( - server_info.metadata.get("tenant_name") or server_info.name + server_info.pro_tenant_name + or server_info.metadata.get("tenant_name") + or server_info.name ) if server_name: self.server_name = server_name - organization_id = server_info.metadata.get("organization_id") - if organization_id: - self.organization_id = UUID(organization_id) + if server_info.pro_organization_id: + self.organization_id = server_info.pro_organization_id + if server_info.pro_tenant_id: + self.server_id = server_info.pro_tenant_id + if server_info.pro_organization_name: + self.organization_name = server_info.pro_organization_name + if server_info.pro_tenant_name: + self.tenant_name = server_info.pro_tenant_name + if server_info.pro_api_url: + self.pro_api_url = server_info.pro_api_url + if server_info.pro_dashboard_url: + self.pro_dashboard_url = server_info.pro_dashboard_url self.version = server_info.version or self.version # The server information was retrieved from the server itself, so we # can assume that the server is available self.status = "available" else: + self.deployment_type = ServerDeploymentType.CLOUD self.server_id = server_info.id self.server_name = server_info.name self.organization_name = server_info.organization_name self.organization_id = server_info.organization_id + self.tenant_name = server_info.name + self.tenant_id = server_info.id self.status = server_info.status self.version = server_info.version @@ -248,9 +271,10 @@ def dashboard_url(self) -> str: """ if self.organization_id and self.server_id: return ( - ZENML_PRO_URL + (self.pro_dashboard_url or ZENML_PRO_URL) + f"/organizations/{str(self.organization_id)}/tenants/{str(self.server_id)}" ) + return self.url @property @@ -262,8 +286,8 @@ def dashboard_organization_url(self) -> str: """ if self.organization_id: return ( - ZENML_PRO_URL + f"/organizations/{str(self.organization_id)}" - ) + self.pro_dashboard_url or ZENML_PRO_URL + ) + f"/organizations/{str(self.organization_id)}" return "" @property diff --git a/src/zenml/login/credentials_store.py b/src/zenml/login/credentials_store.py index 34bfe440db..efb9c0f895 100644 --- a/src/zenml/login/credentials_store.py +++ b/src/zenml/login/credentials_store.py @@ -25,7 +25,6 @@ from zenml.io import fileio from zenml.logger import get_logger from zenml.login.credentials import APIToken, ServerCredentials, ServerType -from zenml.login.pro.constants import ZENML_PRO_API_URL from zenml.login.pro.tenant.models import TenantRead from zenml.models import OAuthTokenResponse, ServerModel from zenml.utils import yaml_utils @@ -289,10 +288,13 @@ def get_credentials(self, server_url: str) -> Optional[ServerCredentials]: self.check_and_reload_from_file() return self.credentials.get(server_url) - def get_pro_token(self, allow_expired: bool = False) -> Optional[APIToken]: - """Retrieve a valid token from the credentials store for the ZenML Pro API server. + def get_pro_token( + self, pro_api_url: str, allow_expired: bool = False + ) -> Optional[APIToken]: + """Retrieve a valid token from the credentials store for a ZenML Pro API server. Args: + pro_api_url: The URL of the ZenML Pro API server. allow_expired: Whether to allow expired tokens to be returned. The default behavior is to return None if a token does exist but is expired. @@ -300,14 +302,18 @@ def get_pro_token(self, allow_expired: bool = False) -> Optional[APIToken]: Returns: The stored token if it exists and is not expired, None otherwise. """ - return self.get_token(ZENML_PRO_API_URL, allow_expired) + credential = self.get_pro_credentials(pro_api_url, allow_expired) + if credential: + return credential.api_token + return None def get_pro_credentials( - self, allow_expired: bool = False + self, pro_api_url: str, allow_expired: bool = False ) -> Optional[ServerCredentials]: - """Retrieve a valid token from the credentials store for the ZenML Pro API server. + """Retrieve a valid token from the credentials store for a ZenML Pro API server. Args: + pro_api_url: The URL of the ZenML Pro API server. allow_expired: Whether to allow expired tokens to be returned. The default behavior is to return None if a token does exist but is expired. @@ -315,26 +321,47 @@ def get_pro_credentials( Returns: The stored credentials if they exist and are not expired, None otherwise. """ - credential = self.get_credentials(ZENML_PRO_API_URL) + credential = self.get_credentials(pro_api_url) if ( credential + and credential.type == ServerType.PRO_API and credential.api_token and (not credential.api_token.expired or allow_expired) ): return credential return None - def clear_pro_credentials(self) -> None: - """Delete the token from the store for the ZenML Pro API server.""" - self.clear_token(ZENML_PRO_API_URL) + def clear_pro_credentials(self, pro_api_url: str) -> None: + """Delete the token from the store for a ZenML Pro API server. + + Args: + pro_api_url: The URL of the ZenML Pro API server. + """ + self.clear_token(pro_api_url) + + def clear_all_pro_tokens( + self, pro_api_url: str + ) -> List[ServerCredentials]: + """Delete all tokens from the store for ZenML Pro servers connected to a given API server. + + Args: + pro_api_url: The URL of the ZenML Pro API server. - def clear_all_pro_tokens(self) -> None: - """Delete all tokens from the store for ZenML Pro API servers.""" + Returns: + A list of the credentials that were cleared. + """ + credentials_to_clear = [] for server_url, server in self.credentials.copy().items(): - if server.type == ServerType.PRO: + if ( + server.type == ServerType.PRO + and server.pro_api_url + and server.pro_api_url == pro_api_url + ): if server.api_key: continue self.clear_token(server_url) + credentials_to_clear.append(server) + return credentials_to_clear def has_valid_authentication(self, url: str) -> bool: """Check if a valid authentication credential for the given server URL is stored. @@ -357,13 +384,16 @@ def has_valid_authentication(self, url: str) -> bool: token = credential.api_token return token is not None and not token.expired - def has_valid_pro_authentication(self) -> bool: - """Check if a valid token for the ZenML Pro API server is stored. + def has_valid_pro_authentication(self, pro_api_url: str) -> bool: + """Check if a valid token for a ZenML Pro API server is stored. + + Args: + pro_api_url: The URL of the ZenML Pro API server. Returns: bool: True if a valid token is stored, False otherwise. """ - return self.get_token(ZENML_PRO_API_URL) is not None + return self.get_pro_token(pro_api_url) is not None def set_api_key( self, @@ -433,12 +463,14 @@ def set_token( self, server_url: str, token_response: OAuthTokenResponse, + is_zenml_pro: bool = False, ) -> APIToken: """Store an API token received from an OAuth2 server. Args: server_url: The server URL for which the token is to be stored. token_response: Token response received from an OAuth2 server. + is_zenml_pro: Whether the token is for a ZenML Pro server. Returns: APIToken: The stored token. @@ -468,7 +500,6 @@ def set_token( expires_in=token_response.expires_in, expires_at=expires_at, leeway=leeway, - cookie_name=token_response.cookie_name, device_id=token_response.device_id, device_metadata=token_response.device_metadata, ) @@ -477,10 +508,14 @@ def set_token( if credential: credential.api_token = api_token else: - self.credentials[server_url] = ServerCredentials( + credential = self.credentials[server_url] = ServerCredentials( url=server_url, api_token=api_token ) + if is_zenml_pro: + # This is how we encode that the token is for a ZenML Pro server + credential.pro_api_url = server_url + self._save_credentials() return api_token diff --git a/src/zenml/login/pro/client.py b/src/zenml/login/pro/client.py index 516f84833c..cbc6925d5b 100644 --- a/src/zenml/login/pro/client.py +++ b/src/zenml/login/pro/client.py @@ -33,7 +33,6 @@ from zenml.logger import get_logger from zenml.login.credentials import APIToken from zenml.login.credentials_store import get_credentials_store -from zenml.login.pro.constants import ZENML_PRO_API_URL from zenml.login.pro.models import BaseRestAPIModel from zenml.utils.singleton import SingletonMetaClass from zenml.zen_server.exceptions import exception_from_response @@ -60,14 +59,11 @@ class ZenMLProClient(metaclass=SingletonMetaClass): _tenant: Optional["TenantClient"] = None _organization: Optional["OrganizationClient"] = None - def __init__( - self, url: Optional[str] = None, api_token: Optional[APIToken] = None - ) -> None: + def __init__(self, url: str, api_token: Optional[APIToken] = None) -> None: """Initialize the ZenML Pro client. Args: - url: The URL of the ZenML Pro API server. If not provided, the - default ZenML Pro API server URL is used. + url: The URL of the ZenML Pro API server. api_token: The API token to use for authentication. If not provided, the token is fetched from the credentials store. @@ -75,7 +71,7 @@ def __init__( AuthorizationException: If no API token is provided and no token is found in the credentials store. """ - self._url = url or ZENML_PRO_API_URL + self._url = url if api_token is None: logger.debug( "No ZenML Pro API token provided. Fetching from credentials " diff --git a/src/zenml/login/pro/constants.py b/src/zenml/login/pro/constants.py index 4441367d91..e9ea60abf1 100644 --- a/src/zenml/login/pro/constants.py +++ b/src/zenml/login/pro/constants.py @@ -26,9 +26,3 @@ DEFAULT_ZENML_PRO_URL = "https://cloud.zenml.io" ZENML_PRO_URL = os.getenv(ENV_ZENML_PRO_URL, default=DEFAULT_ZENML_PRO_URL) - -ENV_ZENML_PRO_SERVER_SUBDOMAIN = "ZENML_PRO_SERVER_SUBDOMAIN" -DEFAULT_ZENML_PRO_SERVER_SUBDOMAIN = "cloudinfra.zenml.io" -ZENML_PRO_SERVER_SUBDOMAIN = os.getenv( - ENV_ZENML_PRO_SERVER_SUBDOMAIN, default=DEFAULT_ZENML_PRO_SERVER_SUBDOMAIN -) diff --git a/src/zenml/login/pro/tenant/models.py b/src/zenml/login/pro/tenant/models.py index f5efc57a51..166d7f6af1 100644 --- a/src/zenml/login/pro/tenant/models.py +++ b/src/zenml/login/pro/tenant/models.py @@ -76,7 +76,7 @@ class ZenMLServiceStatus(BaseRestAPIModel): class ZenMLServiceRead(BaseRestAPIModel): """Pydantic Model for viewing a ZenML service.""" - configuration: ZenMLServiceConfiguration = Field( + configuration: Optional[ZenMLServiceConfiguration] = Field( description="The service configuration." ) @@ -133,7 +133,9 @@ def version(self) -> Optional[str]: Returns: The ZenML service version. """ - version = self.zenml_service.configuration.version + version = None + if self.zenml_service.configuration: + version = self.zenml_service.configuration.version if self.zenml_service.status and self.zenml_service.status.version: version = self.zenml_service.status.version diff --git a/src/zenml/login/pro/utils.py b/src/zenml/login/pro/utils.py index 256ced18b6..b58941c6e4 100644 --- a/src/zenml/login/pro/utils.py +++ b/src/zenml/login/pro/utils.py @@ -13,37 +13,16 @@ # permissions and limitations under the License. """ZenML Pro login utils.""" -import re - from zenml.logger import get_logger +from zenml.login.credentials import ServerType from zenml.login.credentials_store import get_credentials_store from zenml.login.pro.client import ZenMLProClient -from zenml.login.pro.constants import ZENML_PRO_SERVER_SUBDOMAIN +from zenml.login.pro.constants import ZENML_PRO_API_URL from zenml.login.pro.tenant.models import TenantStatus logger = get_logger(__name__) -def is_zenml_pro_server_url(url: str) -> bool: - """Check if a given URL is a ZenML Pro server. - - Args: - url: URL to check - - Returns: - True if the URL is a ZenML Pro tenant, False otherwise - """ - domain_regex = ZENML_PRO_SERVER_SUBDOMAIN.replace(".", r"\.") - return bool( - re.match( - r"^(https://)?[a-zA-Z0-9-\.]+\.{domain}/?$".format( - domain=domain_regex - ), - url, - ) - ) - - def get_troubleshooting_instructions(url: str) -> str: """Get troubleshooting instructions for a given ZenML Pro server URL. @@ -54,8 +33,15 @@ def get_troubleshooting_instructions(url: str) -> str: Troubleshooting instructions """ credentials_store = get_credentials_store() - if credentials_store.has_valid_pro_authentication(): - client = ZenMLProClient() + + credentials = credentials_store.get_credentials(url) + if credentials and credentials.type == ServerType.PRO: + pro_api_url = credentials.pro_api_url or ZENML_PRO_API_URL + + if pro_api_url and credentials_store.has_valid_pro_authentication( + pro_api_url + ): + client = ZenMLProClient(pro_api_url) try: servers = client.tenant.list(url=url, member_only=False) diff --git a/src/zenml/login/server_info.py b/src/zenml/login/server_info.py new file mode 100644 index 0000000000..ba4a240e0c --- /dev/null +++ b/src/zenml/login/server_info.py @@ -0,0 +1,52 @@ +# Copyright (c) ZenML GmbH 2024. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at: +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +# or implied. See the License for the specific language governing +# permissions and limitations under the License. +"""ZenML server information retrieval.""" + +from typing import Optional + +from zenml.logger import get_logger +from zenml.models import ServerModel +from zenml.zen_stores.rest_zen_store import ( + RestZenStore, + RestZenStoreConfiguration, +) + +logger = get_logger(__name__) + + +def get_server_info(url: str) -> Optional[ServerModel]: + """Retrieve server information from a remote ZenML server. + + Args: + url: The URL of the ZenML server. + + Returns: + The server information or None if the server info could not be fetched. + """ + # Here we try to leverage the existing RestZenStore support to fetch the + # server info and only the server info, which doesn't actually need + # any authentication. + try: + store = RestZenStore( + config=RestZenStoreConfiguration( + url=url, + ) + ) + return store.server_info + except Exception as e: + logger.warning( + f"Failed to fetch server info from the server running at {url}: {e}" + ) + + return None diff --git a/src/zenml/login/web_login.py b/src/zenml/login/web_login.py index 740c4d7552..66f52ee541 100644 --- a/src/zenml/login/web_login.py +++ b/src/zenml/login/web_login.py @@ -34,13 +34,14 @@ from zenml.logger import get_logger from zenml.login.credentials import APIToken from zenml.login.pro.constants import ZENML_PRO_API_URL -from zenml.login.pro.utils import is_zenml_pro_server_url logger = get_logger(__name__) def web_login( - url: Optional[str] = None, verify_ssl: Optional[Union[str, bool]] = None + url: Optional[str] = None, + verify_ssl: Optional[Union[str, bool]] = None, + pro_api_url: Optional[str] = None, ) -> APIToken: """Implements the OAuth2 Device Authorization Grant flow. @@ -61,6 +62,8 @@ def web_login( verify_ssl: Whether to verify the SSL certificate of the OAuth2 server. If a string is passed, it is interpreted as the path to a CA bundle file. + pro_api_url: The URL of the ZenML Pro API server. If not provided, the + default ZenML Pro API server URL is used. Returns: The response returned by the OAuth2 server. @@ -103,16 +106,16 @@ def web_login( if not url: # If no URL is provided, we use the ZenML Pro API server by default zenml_pro = True - url = base_url = ZENML_PRO_API_URL + url = base_url = pro_api_url or ZENML_PRO_API_URL else: # Get rid of any trailing slashes to prevent issues when having double # slashes in the URL url = url.rstrip("/") - if is_zenml_pro_server_url(url): + if pro_api_url: # This is a ZenML Pro server. The device authentication is done # through the ZenML Pro API. zenml_pro = True - base_url = ZENML_PRO_API_URL + base_url = pro_api_url else: base_url = url @@ -240,4 +243,6 @@ def web_login( ) # Save the token in the credentials store - return credentials_store.set_token(url, token_response) + return credentials_store.set_token( + url, token_response, is_zenml_pro=zenml_pro + ) diff --git a/src/zenml/models/v2/misc/auth_models.py b/src/zenml/models/v2/misc/auth_models.py index 2590ef1a2b..ab6f3a7876 100644 --- a/src/zenml/models/v2/misc/auth_models.py +++ b/src/zenml/models/v2/misc/auth_models.py @@ -119,8 +119,8 @@ class OAuthTokenResponse(BaseModel): token_type: str expires_in: Optional[int] = None refresh_token: Optional[str] = None + csrf_token: Optional[str] = None scope: Optional[str] = None - cookie_name: Optional[str] = None device_id: Optional[UUID] = None device_metadata: Optional[Dict[str, Any]] = None diff --git a/src/zenml/models/v2/misc/server_models.py b/src/zenml/models/v2/misc/server_models.py index be548189c3..6b029d4d4e 100644 --- a/src/zenml/models/v2/misc/server_models.py +++ b/src/zenml/models/v2/misc/server_models.py @@ -107,6 +107,42 @@ class ServerModel(BaseModel): title="Timestamp of latest user activity traced on the server.", ) + pro_dashboard_url: Optional[str] = Field( + None, + title="The base URL of the ZenML Pro dashboard to which the server " + "is connected. Only set if the server is a ZenML Pro server.", + ) + + pro_api_url: Optional[str] = Field( + None, + title="The base URL of the ZenML Pro API to which the server is " + "connected. Only set if the server is a ZenML Pro server.", + ) + + pro_organization_id: Optional[UUID] = Field( + None, + title="The ID of the ZenML Pro organization to which the server is " + "connected. Only set if the server is a ZenML Pro server.", + ) + + pro_organization_name: Optional[str] = Field( + None, + title="The name of the ZenML Pro organization to which the server is " + "connected. Only set if the server is a ZenML Pro server.", + ) + + pro_tenant_id: Optional[UUID] = Field( + None, + title="The ID of the ZenML Pro tenant to which the server is connected. " + "Only set if the server is a ZenML Pro server.", + ) + + pro_tenant_name: Optional[str] = Field( + None, + title="The name of the ZenML Pro tenant to which the server is connected. " + "Only set if the server is a ZenML Pro server.", + ) + def is_local(self) -> bool: """Return whether the server is running locally. @@ -119,6 +155,14 @@ def is_local(self) -> bool: # server ID is the same as the local client (user) ID. return self.id == GlobalConfiguration().user_id + def is_pro_server(self) -> bool: + """Return whether the server is a ZenML Pro server. + + Returns: + True if the server is a ZenML Pro server, False otherwise. + """ + return self.deployment_type == ServerDeploymentType.CLOUD + class ServerLoadInfo(BaseModel): """Domain model for ZenML server load information.""" diff --git a/src/zenml/zen_server/auth.py b/src/zenml/zen_server/auth.py index 5d67fac7d3..c90c32494c 100644 --- a/src/zenml/zen_server/auth.py +++ b/src/zenml/zen_server/auth.py @@ -16,8 +16,8 @@ from contextvars import ContextVar from datetime import datetime, timedelta from typing import Callable, Optional, Union -from urllib.parse import urlencode -from uuid import UUID +from urllib.parse import urlencode, urlparse +from uuid import UUID, uuid4 import requests from fastapi import Depends, Response @@ -63,9 +63,15 @@ UserUpdate, ) from zenml.zen_server.cache import cache_result +from zenml.zen_server.csrf import CSRFToken from zenml.zen_server.exceptions import http_exception_from_error from zenml.zen_server.jwt import JWTToken -from zenml.zen_server.utils import server_config, zen_store +from zenml.zen_server.utils import ( + get_zenml_headers, + is_same_or_subdomain, + server_config, + zen_store, +) logger = get_logger(__name__) @@ -174,6 +180,7 @@ def authenticate_credentials( user_name_or_id: Optional[Union[str, UUID]] = None, password: Optional[str] = None, access_token: Optional[str] = None, + csrf_token: Optional[str] = None, activation_token: Optional[str] = None, ) -> AuthContext: """Verify if user authentication credentials are valid. @@ -192,6 +199,7 @@ def authenticate_credentials( user_name_or_id: The username or user ID. password: The password. access_token: The access token. + csrf_token: The CSRF token. activation_token: The activation token. Returns: @@ -253,6 +261,22 @@ def authenticate_credentials( logger.exception(error) raise CredentialsNotValid(error) + if decoded_token.session_id: + if not csrf_token: + error = "Authentication error: missing CSRF token" + logger.error(error) + raise CredentialsNotValid(error) + + decoded_csrf_token = CSRFToken.decode_token(csrf_token) + + if decoded_csrf_token.session_id != decoded_token.session_id: + error = ( + "Authentication error: CSRF token does not match the " + "access token" + ) + logger.error(error) + raise CredentialsNotValid(error) + try: user_model = zen_store().get_user( user_name_or_id=decoded_token.user_id, include_private=True @@ -282,6 +306,14 @@ def authenticate_credentials( device_model: Optional[OAuthDeviceInternalResponse] = None if decoded_token.device_id: + if server_config().auth_scheme in [ + AuthScheme.NO_AUTH, + AuthScheme.EXTERNAL, + ]: + error = "Authentication error: device authorization is not supported." + logger.error(error) + raise CredentialsNotValid(error) + # Access tokens that have been issued for a device are only valid # for that device, so we need to check if the device ID matches any # of the valid devices in the database. @@ -660,6 +692,7 @@ def authenticate_external_user( # Get the user information from the external authenticator user_info_url = config.external_user_info_url headers = {"Authorization": "Bearer " + external_access_token} + headers.update(get_zenml_headers()) query_params = dict(server_id=str(config.get_external_server_id())) try: @@ -817,6 +850,7 @@ def authenticate_api_key( def generate_access_token( user_id: UUID, response: Optional[Response] = None, + request: Optional[Request] = None, device: Optional[OAuthDeviceInternalResponse] = None, api_key: Optional[APIKeyInternalResponse] = None, expires_in: Optional[int] = None, @@ -828,7 +862,11 @@ def generate_access_token( Args: user_id: The ID of the user. - response: The FastAPI response object. + response: The FastAPI response object. If passed, the access + token will also be set as an HTTP only cookie in the response. + request: The FastAPI request object. Used to determine the request + origin and to decide whether to use cross-site security measures for + the access token cookie. device: The device used for authentication. api_key: The service account API key used for authentication. expires_in: The number of seconds until the token expires. If not set, @@ -866,6 +904,46 @@ def generate_access_token( ) expires_in = config.jwt_token_expire_minutes * 60 + # Figure out if this is a same-site request or a cross-site request + same_site = True + if response and request: + # Extract the origin domain from the request; use the referer as a + # fallback + origin_domain: Optional[str] = None + origin = request.headers.get("origin", request.headers.get("referer")) + if origin: + # If the request origin is known, we use it to determine whether + # this is a cross-site request and enable additional security + # measures. + origin_domain = urlparse(origin).netloc + + server_domain: Optional[str] = config.auth_cookie_domain + # If the server's cookie domain is not explicitly set in the + # server's configuration, we use other sources to determine it: + # + # 1. the server's root URL, if set in the server's configuration + # 2. the X-Forwarded-Host header, if set by the reverse proxy + # 3. the request URL, if all else fails + if not server_domain and config.server_url: + server_domain = urlparse(config.server_url).netloc + if not server_domain: + server_domain = request.headers.get( + "x-forwarded-host", request.url.netloc + ) + + # Same-site requests can come from the same domain or from a + # subdomain of the domain used to issue cookies. + if origin_domain and server_domain: + same_site = is_same_or_subdomain(origin_domain, server_domain) + + csrf_token: Optional[str] = None + session_id: Optional[UUID] = None + if not same_site: + # If responding to a cross-site login request, we need to generate and + # sign a CSRF token associated with the authentication session. + session_id = uuid4() + csrf_token = CSRFToken(session_id=session_id).encode() + access_token = JWTToken( user_id=user_id, device_id=device.id if device else None, @@ -873,15 +951,18 @@ def generate_access_token( schedule_id=schedule_id, pipeline_run_id=pipeline_run_id, step_run_id=step_run_id, + # Set the session ID if this is a cross-site request + session_id=session_id, ).encode(expires=expires) - if not device and response: + if response: # Also set the access token as an HTTP only cookie in the response response.set_cookie( key=config.get_auth_cookie_name(), value=access_token, httponly=True, - samesite="lax", + secure=not same_site, + samesite="lax" if same_site else "none", max_age=config.jwt_token_expire_minutes * 60 if config.jwt_token_expire_minutes else None, @@ -889,7 +970,10 @@ def generate_access_token( ) return OAuthTokenResponse( - access_token=access_token, expires_in=expires_in, token_type="bearer" + access_token=access_token, + expires_in=expires_in, + token_type="bearer", + csrf_token=csrf_token, ) @@ -945,6 +1029,7 @@ async def __call__(self, request: Request) -> Optional[str]: def oauth2_authentication( + request: Request, token: str = Depends( CookieOAuth2TokenBearer( tokenUrl=server_config().root_url_path + API + VERSION_1 + LOGIN, @@ -955,14 +1040,18 @@ def oauth2_authentication( Args: token: The JWT bearer token to be authenticated. + request: The FastAPI request object. Returns: The authentication context reflecting the authenticated user. # noqa: DAR401 """ + csrf_token = request.headers.get("X-CSRF-Token") try: - auth_context = authenticate_credentials(access_token=token) + auth_context = authenticate_credentials( + access_token=token, csrf_token=csrf_token + ) except CredentialsNotValid as e: # We want to be very explicit here and return a CredentialsNotValid # exception encoded as a 401 Unauthorized error encoded, so that the diff --git a/src/zenml/zen_server/cloud_utils.py b/src/zenml/zen_server/cloud_utils.py index 083e690a35..0aff913563 100644 --- a/src/zenml/zen_server/cloud_utils.py +++ b/src/zenml/zen_server/cloud_utils.py @@ -1,114 +1,92 @@ """Utils concerning anything concerning the cloud control plane backend.""" -import os from datetime import datetime, timedelta, timezone from typing import Any, Dict, Optional import requests -from pydantic import BaseModel, ConfigDict, field_validator from requests.adapters import HTTPAdapter, Retry +from zenml.config.server_config import ServerProConfiguration from zenml.exceptions import SubscriptionUpgradeRequiredError -from zenml.zen_server.utils import server_config - -ZENML_CLOUD_RBAC_ENV_PREFIX = "ZENML_CLOUD_" +from zenml.zen_server.utils import get_zenml_headers, server_config _cloud_connection: Optional["ZenMLCloudConnection"] = None -class ZenMLCloudConfiguration(BaseModel): - """ZenML Pro RBAC configuration.""" - - api_url: str - oauth2_client_id: str - oauth2_client_secret: str - oauth2_audience: str - - @field_validator("api_url") - @classmethod - def _strip_trailing_slashes_url(cls, url: str) -> str: - """Strip any trailing slashes on the API URL. - - Args: - url: The API URL. - - Returns: - The API URL with potential trailing slashes removed. - """ - return url.rstrip("/") - - @classmethod - def from_environment(cls) -> "ZenMLCloudConfiguration": - """Get the RBAC configuration from environment variables. - - Returns: - The RBAC configuration. - """ - env_config: Dict[str, Any] = {} - for k, v in os.environ.items(): - if v == "": - continue - if k.startswith(ZENML_CLOUD_RBAC_ENV_PREFIX): - env_config[k[len(ZENML_CLOUD_RBAC_ENV_PREFIX) :].lower()] = v - - return ZenMLCloudConfiguration(**env_config) - - model_config = ConfigDict( - # Allow extra attributes from configs of previous ZenML versions to - # permit downgrading - extra="allow" - ) - - class ZenMLCloudConnection: """Class to use for communication between server and control plane.""" def __init__(self) -> None: """Initialize the RBAC component.""" - self._config = ZenMLCloudConfiguration.from_environment() + self._config = ServerProConfiguration.get_server_config() self._session: Optional[requests.Session] = None self._token: Optional[str] = None self._token_expires_at: Optional[datetime] = None - def get( - self, endpoint: str, params: Optional[Dict[str, Any]] + def request( + self, + method: str, + endpoint: str, + params: Optional[Dict[str, Any]] = None, + data: Optional[Dict[str, Any]] = None, ) -> requests.Response: - """Send a GET request using the active session. + """Send a request using the active session. Args: + method: The HTTP method to use. endpoint: The endpoint to send the request to. This will be appended to the base URL. params: Parameters to include in the request. + data: Data to include in the request. Raises: + SubscriptionUpgradeRequiredError: If the current subscription tier + is insufficient for the attempted operation. RuntimeError: If the request failed. - SubscriptionUpgradeRequiredError: In case the current subscription - tier is insufficient for the attempted operation. Returns: The response. """ url = self._config.api_url + endpoint - response = self.session.get(url=url, params=params, timeout=7) + response = self.session.request( + method=method, url=url, params=params, json=data, timeout=7 + ) if response.status_code == 401: - # If we get an Unauthorized error from the API serer, we refresh the - # auth token and try again + # Refresh the auth token and try again self._clear_session() - response = self.session.get(url=url, params=params, timeout=7) + response = self.session.request( + method=method, url=url, params=params, json=data, timeout=7 + ) try: response.raise_for_status() - except requests.HTTPError: + except requests.HTTPError as e: if response.status_code == 402: raise SubscriptionUpgradeRequiredError(response.json()) else: raise RuntimeError( - f"Failed with the following error {response} {response.text}" + f"Failed while trying to contact the central zenml pro " + f"service: {e}" ) return response + def get( + self, endpoint: str, params: Optional[Dict[str, Any]] + ) -> requests.Response: + """Send a GET request using the active session. + + Args: + endpoint: The endpoint to send the request to. This will be appended + to the base URL. + params: Parameters to include in the request. + + Returns: + The response. + """ + return self.request(method="GET", endpoint=endpoint, params=params) + def post( self, endpoint: str, @@ -123,33 +101,33 @@ def post( params: Parameters to include in the request. data: Data to include in the request. - Raises: - RuntimeError: If the request failed. - Returns: The response. """ - url = self._config.api_url + endpoint - - response = self.session.post( - url=url, params=params, json=data, timeout=7 + return self.request( + method="POST", endpoint=endpoint, params=params, data=data ) - if response.status_code == 401: - # Refresh the auth token and try again - self._clear_session() - response = self.session.post( - url=url, params=params, json=data, timeout=7 - ) - try: - response.raise_for_status() - except requests.HTTPError as e: - raise RuntimeError( - f"Failed while trying to contact the central zenml pro " - f"service: {e}" - ) + def patch( + self, + endpoint: str, + params: Optional[Dict[str, Any]] = None, + data: Optional[Dict[str, Any]] = None, + ) -> requests.Response: + """Send a PATCH request using the active session. - return response + Args: + endpoint: The endpoint to send the request to. This will be appended + to the base URL. + params: Parameters to include in the request. + data: Data to include in the request. + + Returns: + The response. + """ + return self.request( + method="PATCH", endpoint=endpoint, params=params, data=data + ) @property def session(self) -> requests.Session: @@ -169,6 +147,8 @@ def session(self) -> requests.Session: self._session = requests.Session() token = self._fetch_auth_token() self._session.headers.update({"Authorization": "Bearer " + token}) + # Add the ZenML specific headers + self._session.headers.update(get_zenml_headers()) retries = Retry( total=5, backoff_factor=0.1, status_forcelist=[502, 504] @@ -194,7 +174,7 @@ def _clear_session(self) -> None: self._token_expires_at = None def _fetch_auth_token(self) -> str: - """Fetch an auth token for the Cloud API from auth0. + """Fetch an auth token from the Cloud API. Raises: RuntimeError: If the auth token can't be fetched. @@ -210,11 +190,14 @@ def _fetch_auth_token(self) -> str: ): return self._token - # Get an auth token from auth0 + # Get an auth token from the Cloud API login_url = f"{self._config.api_url}/auth/login" headers = {"content-type": "application/x-www-form-urlencoded"} + # Add zenml specific headers to the request + headers.update(get_zenml_headers()) payload = { - "client_id": self._config.oauth2_client_id, + # The client ID is the external server ID + "client_id": str(server_config().get_external_server_id()), "client_secret": self._config.oauth2_client_secret, "audience": self._config.oauth2_audience, "grant_type": "client_credentials", @@ -225,7 +208,9 @@ def _fetch_auth_token(self) -> str: ) response.raise_for_status() except Exception as e: - raise RuntimeError(f"Error fetching auth token from auth0: {e}") + raise RuntimeError( + f"Error fetching auth token from the Cloud API: {e}" + ) json_response = response.json() access_token = json_response.get("access_token", "") @@ -237,7 +222,9 @@ def _fetch_auth_token(self) -> str: or not expires_in or not isinstance(expires_in, int) ): - raise RuntimeError("Could not fetch auth token from auth0.") + raise RuntimeError( + "Could not fetch auth token from the Cloud API." + ) self._token = access_token self._token_expires_at = datetime.now(timezone.utc) + timedelta( @@ -259,3 +246,8 @@ def cloud_connection() -> ZenMLCloudConnection: _cloud_connection = ZenMLCloudConnection() return _cloud_connection + + +def send_pro_tenant_status_update() -> None: + """Send a tenant status update to the Cloud API.""" + cloud_connection().patch("/tenant_status") diff --git a/src/zenml/zen_server/csrf.py b/src/zenml/zen_server/csrf.py new file mode 100644 index 0000000000..9bf51cab6e --- /dev/null +++ b/src/zenml/zen_server/csrf.py @@ -0,0 +1,91 @@ +# Copyright (c) ZenML GmbH 2024. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at: +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +# or implied. See the License for the specific language governing +# permissions and limitations under the License. +"""CSRF token utilities module for ZenML server.""" + +from uuid import UUID + +from pydantic import BaseModel + +from zenml.exceptions import CredentialsNotValid +from zenml.logger import get_logger +from zenml.zen_server.utils import server_config + +logger = get_logger(__name__) + + +class CSRFToken(BaseModel): + """Pydantic object representing a CSRF token. + + Attributes: + session_id: The id of the authenticated session. + """ + + session_id: UUID + + @classmethod + def decode_token( + cls, + token: str, + ) -> "CSRFToken": + """Decodes a CSRF token. + + Decodes a CSRF access token and returns a `CSRFToken` object with the + information retrieved from its contents. + + Args: + token: The encoded CSRF token. + + Returns: + The decoded CSRF token. + + Raises: + CredentialsNotValid: If the token is invalid. + """ + from itsdangerous import BadData, BadSignature, URLSafeSerializer + + config = server_config() + + serializer = URLSafeSerializer(config.jwt_secret_key) + try: + # Decode and verify the token + data = serializer.loads(token) + except BadSignature as e: + raise CredentialsNotValid( + "Invalid CSRF token: signature mismatch" + ) from e + except BadData as e: + raise CredentialsNotValid("Invalid CSRF token") from e + + try: + return CSRFToken(session_id=UUID(data)) + except ValueError as e: + raise CredentialsNotValid( + "Invalid CSRF token: the session ID is not a valid UUID" + ) from e + + def encode(self) -> str: + """Creates a CSRF token. + + Encodes, signs and returns a CSRF access token. + + Returns: + The generated CSRF token. + """ + from itsdangerous import URLSafeSerializer + + config = server_config() + + serializer = URLSafeSerializer(config.jwt_secret_key) + token = serializer.dumps(str(self.session_id)) + return token diff --git a/src/zenml/zen_server/deploy/helm/templates/NOTES.txt b/src/zenml/zen_server/deploy/helm/templates/NOTES.txt index 51d34707cd..b3a9e0fc10 100644 --- a/src/zenml/zen_server/deploy/helm/templates/NOTES.txt +++ b/src/zenml/zen_server/deploy/helm/templates/NOTES.txt @@ -1,3 +1,24 @@ +{{- if .Values.zenml.pro.enabled }} + +The ZenML Pro server API is now active and ready to use at the following URL: + + {{ .Values.zenml.serverURL }} + +{{- if .Values.zenml.pro.enrollmentKey }} + +The following enrollment key has been used to enroll your server in the ZenML Pro control plane: + + {{ .Values.zenml.pro.enrollmentKey }} + +{{- else }} + +An enrollment key has been auto-generated for your server. Please use the following command to fetch the enrollment key: + + kubectl get secret {{ include "zenml.fullname" . }} -o jsonpath="{.data.ZENML_SERVER_PRO_OAUTH2_CLIENT_SECRET}" | base64 --decode + +{{- end }} + +{{- else }} {{- if .Values.zenml.ingress.enabled }} {{- if .Values.zenml.ingress.host }} @@ -28,3 +49,4 @@ You can get the ZenML server URL by running these commands: {{- end }} {{- end }} +{{- end }} diff --git a/src/zenml/zen_server/deploy/helm/templates/_environment.tpl b/src/zenml/zen_server/deploy/helm/templates/_environment.tpl index 2f64d5881c..d971667135 100644 --- a/src/zenml/zen_server/deploy/helm/templates/_environment.tpl +++ b/src/zenml/zen_server/deploy/helm/templates/_environment.tpl @@ -140,8 +140,58 @@ Returns: A dictionary with the non-secret values configured for the ZenML server. */}} {{- define "zenml.serverConfigurationAttrs" -}} + +{{- if .ZenML.pro.enabled }} +deployment_type: cloud +pro_api_url: "{{ .ZenML.pro.apiURL }}" +pro_dashboard_url: "{{ .ZenML.pro.dashboardURL }}" +pro_oauth2_audience: "{{ .ZenML.pro.apiURL }}" +pro_organization_id: "{{ .ZenML.pro.organizationID }}" +pro_tenant_id: "{{ .ZenML.pro.tenantID }}" +{{- if .ZenML.pro.tenantName }} +pro_tenant_name: "{{ .ZenML.pro.tenantName }}" +{{- end }} +{{- if .ZenML.pro.organizationName }} +pro_organization_name: "{{ .ZenML.pro.organizationName }}" +{{- end }} +{{- if .ZenML.pro.extraCorsOrigins }} +cors_allow_origins: "{{ join "," .ZenML.pro.extraCorsOrigins }}" +{{- end }} +{{- if .ZenML.auth.jwtTokenExpireMinutes }} +jwt_token_expire_minutes: {{ .ZenML.auth.jwtTokenExpireMinutes | quote }} +{{- end }} + +{{- else }} + auth_scheme: {{ .ZenML.authType | default .ZenML.auth.authType | quote }} deployment_type: {{ .ZenML.deploymentType | default "kubernetes" }} +{{- if .ZenML.auth.corsAllowOrigins }} +cors_allow_origins: {{ join "," .ZenML.auth.corsAllowOrigins | quote }} +{{- end }} +{{- if .ZenML.auth.externalLoginURL }} +external_login_url: {{ .ZenML.auth.externalLoginURL | quote }} +{{- end }} +{{- if .ZenML.auth.externalUserInfoURL }} +external_user_info_url: {{ .ZenML.auth.externalUserInfoURL | quote }} +{{- end }} +{{- if .ZenML.auth.externalServerID }} +external_server_id: {{ .ZenML.auth.externalServerID | quote }} +{{- end }} +{{- if .ZenML.auth.jwtTokenExpireMinutes }} +jwt_token_expire_minutes: {{ .ZenML.auth.jwtTokenExpireMinutes | quote }} +{{- end }} +{{- if .ZenML.auth.rbacImplementationSource }} +rbac_implementation_source: {{ .ZenML.auth.rbacImplementationSource | quote }} +{{- end }} +{{- if .ZenML.auth.featureGateImplementationSource }} +feature_gate_implementation_source: {{ .ZenML.auth.featureGateImplementationSource | quote }} +{{- end }} +{{- if .ZenML.dashboardURL }} +dashboard_url: {{ .ZenML.dashboardURL | quote }} +{{- end }} + +{{- end }} + {{- if .ZenML.threadPoolSize }} thread_pool_size: {{ .ZenML.threadPoolSize | quote }} {{- end }} @@ -157,18 +207,12 @@ jwt_token_audience: {{ .ZenML.auth.jwtTokenAudience | quote }} {{- if .ZenML.auth.jwtTokenLeewaySeconds }} jwt_token_leeway_seconds: {{ .ZenML.auth.jwtTokenLeewaySeconds | quote }} {{- end }} -{{- if .ZenML.auth.jwtTokenExpireMinutes }} -jwt_token_expire_minutes: {{ .ZenML.auth.jwtTokenExpireMinutes | quote }} -{{- end }} {{- if .ZenML.auth.authCookieName }} auth_cookie_name: {{ .ZenML.auth.authCookieName | quote }} {{- end }} {{- if .ZenML.auth.authCookieDomain }} auth_cookie_domain: {{ .ZenML.auth.authCookieDomain | quote }} {{- end }} -{{- if .ZenML.auth.corsAllowOrigins }} -cors_allow_origins: {{ join "," .ZenML.auth.corsAllowOrigins | quote }} -{{- end }} {{- if .ZenML.auth.maxFailedDeviceAuthAttempts }} max_failed_device_auth_attempts: {{ .ZenML.auth.maxFailedDeviceAuthAttempts | quote }} {{- end }} @@ -184,30 +228,12 @@ device_expiration_minutes: {{ .ZenML.auth.deviceExpirationMinutes | quote }} {{- if .ZenML.auth.trustedDeviceExpirationMinutes }} trusted_device_expiration_minutes: {{ .ZenML.auth.trustedDeviceExpirationMinutes | quote }} {{- end }} -{{- if .ZenML.auth.externalLoginURL }} -external_login_url: {{ .ZenML.auth.externalLoginURL | quote }} -{{- end }} -{{- if .ZenML.auth.externalUserInfoURL }} -external_user_info_url: {{ .ZenML.auth.externalUserInfoURL | quote }} -{{- end }} -{{- if .ZenML.auth.externalCookieName }} -external_cookie_name: {{ .ZenML.auth.externalCookieName | quote }} -{{- end }} -{{- if .ZenML.auth.externalServerID }} -external_server_id: {{ .ZenML.auth.externalServerID | quote }} -{{- end }} {{- if .ZenML.rootUrlPath }} root_url_path: {{ .ZenML.rootUrlPath | quote }} {{- end }} {{- if .ZenML.serverURL }} server_url: {{ .ZenML.serverURL | quote }} {{- end }} -{{- if .ZenML.dashboardURL }} -dashboard_url: {{ .ZenML.dashboardURL | quote }} -{{- end }} -{{- if .ZenML.auth.rbacImplementationSource }} -rbac_implementation_source: {{ .ZenML.auth.rbacImplementationSource | quote }} -{{- end }} {{- range $key, $value := .ZenML.secure_headers }} secure_headers_{{ $key }}: {{ $value | quote }} {{- end }} diff --git a/src/zenml/zen_server/deploy/helm/templates/server-secret.yaml b/src/zenml/zen_server/deploy/helm/templates/server-secret.yaml index 9ebf401b51..00bb658d21 100644 --- a/src/zenml/zen_server/deploy/helm/templates/server-secret.yaml +++ b/src/zenml/zen_server/deploy/helm/templates/server-secret.yaml @@ -10,6 +10,17 @@ data: {{- else }} ZENML_SERVER_JWT_SECRET_KEY: {{ $prevServerSecret.data.ZENML_SERVER_JWT_SECRET_KEY | default (randAlphaNum 32 | b64enc | quote) }} {{- end }} + + {{- if .Values.zenml.pro.enabled }} + {{- if .Values.zenml.pro.enrollmentKey }} + ZENML_SERVER_PRO_OAUTH2_CLIENT_SECRET: {{ .Values.zenml.pro.enrollmentKey | b64enc | quote }} + {{- else if or .Release.IsInstall (not $prevServerSecret) }} + ZENML_SERVER_PRO_OAUTH2_CLIENT_SECRET: {{ randAlphaNum 64 | b64enc | quote }} + {{- else }} + ZENML_SERVER_PRO_OAUTH2_CLIENT_SECRET: {{ $prevServerSecret.data.ZENML_SERVER_PRO_OAUTH2_CLIENT_SECRET | default (randAlphaNum 64 | b64enc | quote) }} + {{- end }} + {{- end }} + {{- range $k, $v := include "zenml.storeSecretEnvVariables" . | fromYaml}} {{ $k }}: {{ $v | b64enc | quote }} {{- end }} diff --git a/src/zenml/zen_server/deploy/helm/values.yaml b/src/zenml/zen_server/deploy/helm/values.yaml index 5e742c1cb0..dc590139cc 100644 --- a/src/zenml/zen_server/deploy/helm/values.yaml +++ b/src/zenml/zen_server/deploy/helm/values.yaml @@ -27,8 +27,43 @@ zenml: # Overrides the image tag whose default is the chart appVersion. tag: + # ZenML Pro related options. + pro: + # Set `enabled` to true to enable ZenML Pro servers. If set, some of the + # configuration options in the `zenml` section will be overridden with + # values specific to ZenML Pro servers computed from the values set in the + # `pro` section. + enabled: false + + # The URL where the ZenML Pro server API is reachable + apiURL: https://cloudapi.zenml.io + + # The URL where the ZenML Pro dashboard is reachable. + dashboardURL: https://cloud.zenml.io + + # Additional origins to allow in the CORS policy. + extraCorsOrigins: + + # The ID of the ZenML Pro tenant to use. + tenantID: + + # The name of the ZenML Pro tenant to use. + tenantName: + + # The ID of the ZenML Pro organization to use. + organizationID: + + # The name of the ZenML Pro organization to use. + organizationName: + + # The enrollment key to use for the ZenML Pro tenant. If not specified, + # an enrollment key will be auto-generated. + enrollmentKey: + # The URL where the ZenML server API is reachable. If not specified, the # clients will use the same URL used to connect them to the ZenML server. + # + # IMPORTANT: this value must be set for ZenML Pro servers. serverURL: # The URL where the ZenML dashboard is reachable. @@ -39,6 +74,8 @@ zenml: # This is value is used to compute the dashboard URLs during the web login # authentication workflow, to print dashboard URLs in log messages when # running a pipeline and for other similar tasks. + # + # This value is overridden if the `zenml.pro.enabled` value is set. dashboardURL: debug: true @@ -48,6 +85,8 @@ zenml: # ZenML server deployment type. This field is used for telemetry purposes. # Example values are "local", "kubernetes", "aws", "gcp", "azure". + # + # This value is overridden if the `zenml.pro.enabled` value is set. deploymentType: # Authentication settings that control how the ZenML server authenticates @@ -60,6 +99,8 @@ zenml: # HTTP_BASIC - HTTP Basic authentication # OAUTH2_PASSWORD_BEARER - OAuth2 password bearer # EXTERNAL - External authentication (e.g. via a remote authenticator) + # + # This value is overridden if the `zenml.pro.enabled` value is set. authType: OAUTH2_PASSWORD_BEARER # The secret key used to sign JWT tokens. This should be set to @@ -91,7 +132,6 @@ zenml: # ZenML Server ID. jwtTokenIssuer: - # The audience of the JWT tokens. If not specified, the audience is set to # the ZenML Server ID. jwtTokenAudience: @@ -102,6 +142,8 @@ zenml: # The expiration time of JWT tokens in minutes. If not specified, generated # JWT tokens will not be set to expire. + # + # This value is automatically set if the `zenml.pro.enabled` value is set. jwtTokenExpireMinutes: # The name of the http-only cookie used to store the JWT tokens used to @@ -117,20 +159,31 @@ zenml: # The origins allowed to make cross-origin requests to the ZenML server. If # not specified, all origins are allowed. Set this when the ZenML dashboard # is hosted on a different domain than the ZenML server. + # + # This value is overridden if the `zenml.pro.enabled` value is set. corsAllowOrigins: - "*" # The maximum number of failed authentication attempts allowed for an OAuth # 2.0 device before the device is locked. + # + # This value is ignored if the `zenml.auth.authType` is set to `EXTERNAL` or + # `NO_AUTH`. maxFailedDeviceAuthAttempts: 3 # The timeout in seconds after which a pending OAuth 2.0 device # authorization request expires. + # + # This value is ignored if the `zenml.auth.authType` is set to `EXTERNAL` or + # `NO_AUTH`. deviceAuthTimeout: 300 # The polling interval in seconds used by clients to poll the OAuth 2.0 # device authorization endpoint for the status of a pending device # authorization request. + # + # This value is ignored if the `zenml.auth.authType` is set to `EXTERNAL` or + # `NO_AUTH`. deviceAuthPollingInterval: 5 # The time in minutes that an OAuth 2.0 device is allowed to be used to @@ -139,6 +192,9 @@ zenml: # be used indefinitely. This controls the expiration time of the JWT tokens # issued to clients after they have authenticated with the ZenML server # using an OAuth 2.0 device. + # + # This value is ignored if the `zenml.auth.authType` is set to `EXTERNAL` or + # `NO_AUTH`. deviceExpirationMinutes: # The time in minutes that a trusted OAuth 2.0 device is allowed to be used @@ -147,33 +203,46 @@ zenml: # be used indefinitely. This controls the expiration time of the JWT tokens # issued to clients after they have authenticated with the ZenML server # using an OAuth 2.0 device that was previously trusted by the user. + # + # This value is ignored if the `zenml.auth.authType` is set to `EXTERNAL` or + # `NO_AUTH`. trustedDeviceExpirationMinutes: # The login URL of an external authenticator service to use with the # `EXTERNAL` authentication scheme. Only relevant if `zenml.auth.authType` # is set to `EXTERNAL`. + # + # This value is overridden if the `zenml.pro.enabled` value is set. externalLoginURL: # The user info URL of an external authenticator service to use with the # `EXTERNAL` authentication scheme. Only relevant if `zenml.auth.authType` # is set to `EXTERNAL`. + # + # This value is overridden if the `zenml.pro.enabled` value is set. externalUserInfoURL: - # The name of the http-only cookie used to store the bearer token used to - # authenticate with the external authenticator service. Only relevant if - # `zenml.auth.authType` is set to `EXTERNAL`. - externalCookieName: - # The UUID of the ZenML server to use with the `EXTERNAL` authentication # scheme. If not specified, the regular ZenML server ID (deployment ID) is # used. + # + # This value is overridden if the `zenml.pro.enabled` value is set. externalServerID: # Source pointing to a class implementing the RBAC interface defined by - # `zenml.zen_server.rbac_interface.RBACInterface`. If not specified, + # `zenml.zen_server.rbac.rbac_interface.RBACInterface`. If not specified, # RBAC will not be enabled for this server. + # + # This value is overridden if the `zenml.pro.enabled` value is set. rbacImplementationSource: + # Source pointing to a class implementing the feature gate interface defined + # by `zenml.zen_server.feature_gate.feature_gate_interface.FeatureGateInterface`. + # If not specified, feature gating will not be enabled for this server. + # + # This value is overridden if the `zenml.pro.enabled` value is set. + featureGateImplementationSource: + # The root URL path to use when behind a proxy. This is useful when the # `rewrite-target` annotation is used in the ingress controller, e.g.: # diff --git a/src/zenml/zen_server/feature_gate/feature_gate_interface.py b/src/zenml/zen_server/feature_gate/feature_gate_interface.py index df4a5d3fc7..5a93fecc34 100644 --- a/src/zenml/zen_server/feature_gate/feature_gate_interface.py +++ b/src/zenml/zen_server/feature_gate/feature_gate_interface.py @@ -20,7 +20,7 @@ class FeatureGateInterface(ABC): - """RBAC interface definition.""" + """Feature gate interface definition.""" @abstractmethod def check_entitlement(self, resource: ResourceType) -> None: diff --git a/src/zenml/zen_server/jwt.py b/src/zenml/zen_server/jwt.py index eafa61a784..d1150ac028 100644 --- a/src/zenml/zen_server/jwt.py +++ b/src/zenml/zen_server/jwt.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express # or implied. See the License for the specific language governing # permissions and limitations under the License. -"""Authentication module for ZenML server.""" +"""JWT utilities module for ZenML server.""" from datetime import datetime, timedelta from typing import ( @@ -45,6 +45,7 @@ class JWTToken(BaseModel): issued. step_run_id: The id of the step run for which the token was issued. + session_id: The id of the authenticated session (used for CSRF). claims: The original token claims. """ @@ -54,6 +55,7 @@ class JWTToken(BaseModel): schedule_id: Optional[UUID] = None pipeline_run_id: Optional[UUID] = None step_run_id: Optional[UUID] = None + session_id: Optional[UUID] = None claims: Dict[str, Any] = {} @classmethod @@ -156,6 +158,16 @@ def decode_token( "UUID" ) + session_id: Optional[UUID] = None + if "session_id" in claims: + try: + session_id = UUID(claims.pop("session_id")) + except ValueError: + raise CredentialsNotValid( + "Invalid JWT token: the session_id claim is not a valid " + "UUID" + ) + return JWTToken( user_id=user_id, device_id=device_id, @@ -163,6 +175,7 @@ def decode_token( schedule_id=schedule_id, pipeline_run_id=pipeline_run_id, step_run_id=step_run_id, + session_id=session_id, claims=claims, ) @@ -201,6 +214,8 @@ def encode(self, expires: Optional[datetime] = None) -> str: claims["pipeline_run_id"] = str(self.pipeline_run_id) if self.step_run_id: claims["step_run_id"] = str(self.step_run_id) + if self.session_id: + claims["session_id"] = str(self.session_id) return jwt.encode( claims, diff --git a/src/zenml/zen_server/rbac/endpoint_utils.py b/src/zenml/zen_server/rbac/endpoint_utils.py index de2cb958ed..38b8b00499 100644 --- a/src/zenml/zen_server/rbac/endpoint_utils.py +++ b/src/zenml/zen_server/rbac/endpoint_utils.py @@ -19,7 +19,6 @@ from pydantic import BaseModel from zenml.constants import ( - REPORTABLE_RESOURCES, REQUIRES_CUSTOM_RESOURCE_REPORTING, ) from zenml.exceptions import IllegalOperationError @@ -43,6 +42,7 @@ verify_permission, verify_permission_for_model, ) +from zenml.zen_server.utils import server_config AnyRequest = TypeVar("AnyRequest", bound=BaseRequest) AnyResponse = TypeVar("AnyResponse", bound=BaseIdentifiedResponse) # type: ignore[type-arg] @@ -82,7 +82,7 @@ def verify_permissions_and_create_entity( verify_permission(resource_type=resource_type, action=Action.CREATE) needs_usage_increment = ( - resource_type in REPORTABLE_RESOURCES + resource_type in server_config().reportable_resources and resource_type not in REQUIRES_CUSTOM_RESOURCE_REPORTING ) if needs_usage_increment: @@ -129,7 +129,7 @@ def verify_permissions_and_batch_create_entity( verify_permission(resource_type=resource_type, action=Action.CREATE) - if resource_type in REPORTABLE_RESOURCES: + if resource_type in server_config().reportable_resources: raise RuntimeError( "Batch requests are currently not possible with usage-tracked features." ) diff --git a/src/zenml/zen_server/routers/auth_endpoints.py b/src/zenml/zen_server/routers/auth_endpoints.py index fb0eee4391..6626cfec5c 100644 --- a/src/zenml/zen_server/routers/auth_endpoints.py +++ b/src/zenml/zen_server/routers/auth_endpoints.py @@ -175,6 +175,12 @@ def __init__( self.username = username self.password = password or "" elif self.grant_type == OAuthGrantTypes.OAUTH_DEVICE_CODE: + if config.auth_scheme in [AuthScheme.NO_AUTH, AuthScheme.EXTERNAL]: + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, + detail="Device authorization is not supported.", + ) + if not device_code or not client_id: logger.info("Request with missing device code or client ID") raise HTTPException( @@ -237,6 +243,8 @@ def token( ValueError: If the grant type is invalid. """ config = server_config() + cookie_response: Optional[Response] = response + if auth_form_data.grant_type == OAuthGrantTypes.OAUTH_PASSWORD: auth_context = authenticate_credentials( user_name_or_id=auth_form_data.username, @@ -248,38 +256,34 @@ def token( client_id=auth_form_data.client_id, device_code=auth_form_data.device_code, ) + # API tokens for authorized device are only meant for non-web clients + # and should not be stored as cookies + cookie_response = None + elif auth_form_data.grant_type == OAuthGrantTypes.ZENML_API_KEY: auth_context = authenticate_api_key( api_key=auth_form_data.api_key, ) + # API tokens for API keys are only meant for non-web clients + # and should not be stored as cookies + cookie_response = None elif auth_form_data.grant_type == OAuthGrantTypes.ZENML_EXTERNAL: - assert config.external_cookie_name is not None assert config.external_login_url is not None authorization_url = config.external_login_url - # First, try to get the external access token from the external cookie - external_access_token = request.cookies.get( - config.external_cookie_name - ) - if not external_access_token: - # Next, try to get the external access token from the authorization - # header - authorization_header = request.headers.get("Authorization") - if authorization_header: - scheme, _, token = authorization_header.partition(" ") - if token and scheme.lower() == "bearer": - external_access_token = token - logger.info( - "External access token found in authorization header." - ) - else: - logger.info("External access token found in cookie.") + # Try to get the external session token or authorization token from the + # authorization header + authorization_header = request.headers.get("Authorization") + if authorization_header: + scheme, _, token = authorization_header.partition(" ") + if token and scheme.lower() == "bearer": + external_access_token = token - if not external_access_token: + if not authorization_header: logger.info( - "External access token not found. Redirecting to " + "External session or authorization token not found. Redirecting to " "external authenticator." ) @@ -291,13 +295,18 @@ def token( request=request, ) + # TODO: no easy way to detect which type of client issued this request + # (web or non-web) in order to decide whether to store the access token + # as a cookie in the response or not. For now, we always assume a web + # client. else: # Shouldn't happen, because we verify all grants in the form data raise ValueError("Invalid grant type.") return generate_access_token( user_id=auth_context.user.id, - response=response, + response=cookie_response, + request=request, device=auth_context.device, api_key=auth_context.api_key, ) @@ -351,8 +360,18 @@ def device_authorization( Returns: The device authorization response. + + Raises: + HTTPException: If the device authorization is not supported. """ config = server_config() + + if config.auth_scheme in [AuthScheme.NO_AUTH, AuthScheme.EXTERNAL]: + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, + detail="Device authorization is not supported.", + ) + store = zen_store() try: @@ -521,6 +540,8 @@ def api_token( return generate_access_token( user_id=token.user_id, expires_in=expires_in, + # Don't include the access token as a cookie in the response + response=None, ).access_token verify_permission( @@ -621,6 +642,8 @@ def api_token( schedule_id=schedule_id, pipeline_run_id=pipeline_run_id, step_run_id=step_run_id, + # Don't include the access token as a cookie in the response + response=None, # Never expire the token expires_in=0, ).access_token diff --git a/src/zenml/zen_server/routers/models_endpoints.py b/src/zenml/zen_server/routers/models_endpoints.py index 124660e5cf..7c2341dffe 100644 --- a/src/zenml/zen_server/routers/models_endpoints.py +++ b/src/zenml/zen_server/routers/models_endpoints.py @@ -22,7 +22,6 @@ API, MODEL_VERSIONS, MODELS, - REPORTABLE_RESOURCES, VERSION_1, ) from zenml.models import ( @@ -170,7 +169,7 @@ def delete_model( ) if server_config().feature_gate_enabled: - if ResourceType.MODEL in REPORTABLE_RESOURCES: + if ResourceType.MODEL in server_config().reportable_resources: report_decrement(ResourceType.MODEL, resource_id=model.id) diff --git a/src/zenml/zen_server/routers/pipelines_endpoints.py b/src/zenml/zen_server/routers/pipelines_endpoints.py index 837c7738de..43a0507ecb 100644 --- a/src/zenml/zen_server/routers/pipelines_endpoints.py +++ b/src/zenml/zen_server/routers/pipelines_endpoints.py @@ -20,7 +20,6 @@ from zenml.constants import ( API, PIPELINES, - REPORTABLE_RESOURCES, RUNS, VERSION_1, ) @@ -45,6 +44,7 @@ from zenml.zen_server.utils import ( handle_exceptions, make_dependable, + server_config, zen_store, ) @@ -159,7 +159,7 @@ def delete_pipeline( ) should_decrement = ( - ResourceType.PIPELINE in REPORTABLE_RESOURCES + ResourceType.PIPELINE in server_config().reportable_resources and zen_store().count_pipelines(PipelineFilter(name=pipeline.name)) == 0 ) diff --git a/src/zenml/zen_server/routers/stack_deployment_endpoints.py b/src/zenml/zen_server/routers/stack_deployment_endpoints.py index a98a9de54f..0f045502d5 100644 --- a/src/zenml/zen_server/routers/stack_deployment_endpoints.py +++ b/src/zenml/zen_server/routers/stack_deployment_endpoints.py @@ -34,7 +34,7 @@ StackDeploymentInfo, ) from zenml.stack_deployments.utils import get_stack_deployment_class -from zenml.zen_server.auth import AuthContext, authorize +from zenml.zen_server.auth import AuthContext, authorize, generate_access_token from zenml.zen_server.exceptions import error_response from zenml.zen_server.rbac.models import Action, ResourceType from zenml.zen_server.rbac.utils import verify_permission @@ -114,10 +114,10 @@ def get_stack_deployment_config( assert token is not None # A new API token is generated for the stack deployment - expires = datetime.datetime.utcnow() + datetime.timedelta( - minutes=STACK_DEPLOYMENT_API_TOKEN_EXPIRATION - ) - api_token = token.encode(expires=expires) + api_token = generate_access_token( + user_id=token.user_id, + expires_in=STACK_DEPLOYMENT_API_TOKEN_EXPIRATION * 60, + ).access_token return stack_deployment_class( terraform=terraform, diff --git a/src/zenml/zen_server/routers/workspaces_endpoints.py b/src/zenml/zen_server/routers/workspaces_endpoints.py index dd26289d05..32af09d2a4 100644 --- a/src/zenml/zen_server/routers/workspaces_endpoints.py +++ b/src/zenml/zen_server/routers/workspaces_endpoints.py @@ -27,7 +27,6 @@ PIPELINE_BUILDS, PIPELINE_DEPLOYMENTS, PIPELINES, - REPORTABLE_RESOURCES, RUN_METADATA, RUN_TEMPLATES, RUNS, @@ -112,6 +111,7 @@ from zenml.zen_server.utils import ( handle_exceptions, make_dependable, + server_config, zen_store, ) @@ -524,7 +524,7 @@ def create_pipeline( # We limit pipeline namespaces, not pipeline versions needs_usage_increment = ( - ResourceType.PIPELINE in REPORTABLE_RESOURCES + ResourceType.PIPELINE in server_config().reportable_resources and zen_store().count_pipelines(PipelineFilter(name=pipeline.name)) == 0 ) diff --git a/src/zenml/zen_server/utils.py b/src/zenml/zen_server/utils.py index 86414385ff..65c8151625 100644 --- a/src/zenml/zen_server/utils.py +++ b/src/zenml/zen_server/utils.py @@ -20,6 +20,7 @@ TYPE_CHECKING, Any, Callable, + Dict, List, Optional, Tuple, @@ -31,6 +32,7 @@ from pydantic import BaseModel, ValidationError +from zenml import __version__ as zenml_version from zenml.config.global_config import GlobalConfiguration from zenml.config.server_config import ServerConfiguration from zenml.constants import ( @@ -571,3 +573,65 @@ def is_user_request(request: "Request") -> bool: # If none of the above conditions are met, consider it a user request return True + + +def is_same_or_subdomain(source_domain: str, target_domain: str) -> bool: + """Check if the source domain is the same or a subdomain of the target domain. + + Examples: + is_same_or_subdomain("example.com", "example.com") -> True + is_same_or_subdomain("alpha.example.com", "example.com") -> True + is_same_or_subdomain("alpha.example.com", ".example.com") -> True + is_same_or_subdomain("example.com", "alpha.example.com") -> False + is_same_or_subdomain("alpha.beta.example.com", "beta.example.com") -> True + is_same_or_subdomain("alpha.beta.example.com", "alpha.example.com") -> False + is_same_or_subdomain("alphabeta.gamma.example", "beta.gamma.example") -> False + + Args: + source_domain: The source domain to check. + target_domain: The target domain to compare against. + + Returns: + True if the source domain is the same or a subdomain of the target + domain, False otherwise. + """ + import tldextract + + # Extract the registered domain and suffix for both + src_parts = tldextract.extract(source_domain) + tgt_parts = tldextract.extract(target_domain) + + if src_parts == tgt_parts: + return True # Same domain + + # Reconstruct the base domains (e.g., example.com) + src_base_domain = f"{src_parts.domain}.{src_parts.suffix}" + tgt_base_domain = f"{tgt_parts.domain}.{tgt_parts.suffix}" + + if src_base_domain != tgt_base_domain: + return False # Different base domains + + if tgt_parts.subdomain == "": + return True # Subdomain + + if src_parts.subdomain.endswith(f".{tgt_parts.subdomain.lstrip('.')}"): + return True # Subdomain of subdomain + + return False + + +def get_zenml_headers() -> Dict[str, str]: + """Get the ZenML specific headers to be included in requests made by the server. + + Returns: + The ZenML specific headers. + """ + config = server_config() + headers = { + "zenml-server-id": str(config.get_external_server_id()), + "zenml-server-version": zenml_version, + } + if config.server_url: + headers["zenml-server-url"] = config.server_url + + return headers diff --git a/src/zenml/zen_server/zen_server_api.py b/src/zenml/zen_server/zen_server_api.py index 3c4b20638a..060e9db59c 100644 --- a/src/zenml/zen_server/zen_server_api.py +++ b/src/zenml/zen_server/zen_server_api.py @@ -54,6 +54,7 @@ ) from zenml.enums import AuthScheme, SourceContextTypes from zenml.models import ServerDeploymentType +from zenml.zen_server.cloud_utils import send_pro_tenant_status_update from zenml.zen_server.exceptions import error_detail from zenml.zen_server.routers import ( actions_endpoints, @@ -389,6 +390,10 @@ def initialize() -> None: initialize_plugins() initialize_secure_headers() initialize_memcache(cfg.memcache_max_capacity, cfg.memcache_default_expiry) + if cfg.deployment_type == ServerDeploymentType.CLOUD: + # Send a tenant status update to the Cloud API to indicate that the + # ZenML server is running or to update the version and server URL. + send_pro_tenant_status_update() DASHBOARD_REDIRECT_URL = None diff --git a/src/zenml/zen_stores/base_zen_store.py b/src/zenml/zen_stores/base_zen_store.py index 11467c4481..42c812800f 100644 --- a/src/zenml/zen_stores/base_zen_store.py +++ b/src/zenml/zen_stores/base_zen_store.py @@ -47,6 +47,7 @@ from zenml.logger import get_logger from zenml.models import ( ServerDatabaseType, + ServerDeploymentType, ServerModel, StackFilter, StackResponse, @@ -390,7 +391,7 @@ def get_store_info(self) -> ServerModel: secrets_store_type = SecretsStoreType.NONE if isinstance(self, SqlZenStore) and self.config.secrets_store: secrets_store_type = self.config.secrets_store.type - return ServerModel( + store_info = ServerModel( id=GlobalConfiguration().user_id, active=True, version=zenml.__version__, @@ -405,6 +406,23 @@ def get_store_info(self) -> ServerModel: metadata=metadata, ) + # Add ZenML Pro specific store information to the server model, if available. + if store_info.deployment_type == ServerDeploymentType.CLOUD: + from zenml.config.server_config import ServerProConfiguration + + pro_config = ServerProConfiguration.get_server_config() + + store_info.pro_api_url = pro_config.api_url + store_info.pro_dashboard_url = pro_config.dashboard_url + store_info.pro_organization_id = pro_config.organization_id + store_info.pro_tenant_id = pro_config.tenant_id + if pro_config.tenant_name: + store_info.pro_tenant_name = pro_config.tenant_name + if pro_config.organization_name: + store_info.pro_organization_name = pro_config.organization_name + + return store_info + def is_local_store(self) -> bool: """Check if the store is local or connected to a local ZenML server. diff --git a/src/zenml/zen_stores/rest_zen_store.py b/src/zenml/zen_stores/rest_zen_store.py index e0e364987d..d8fe6b0090 100644 --- a/src/zenml/zen_stores/rest_zen_store.py +++ b/src/zenml/zen_stores/rest_zen_store.py @@ -124,9 +124,9 @@ from zenml.logger import get_logger from zenml.login.credentials import APIToken from zenml.login.credentials_store import get_credentials_store +from zenml.login.pro.constants import ZENML_PRO_API_URL from zenml.login.pro.utils import ( get_troubleshooting_instructions, - is_zenml_pro_server_url, ) from zenml.models import ( ActionFilter, @@ -450,6 +450,7 @@ class RestZenStore(BaseZenStore): CONFIG_TYPE: ClassVar[Type[StoreConfiguration]] = RestZenStoreConfiguration _api_token: Optional[APIToken] = None _session: Optional[requests.Session] = None + _server_info: Optional[ServerModel] = None # ==================================== # ZenML Store interface implementation @@ -469,7 +470,7 @@ def _initialize(self) -> None: """ try: client_version = zenml.__version__ - server_version = self.get_store_info().version + server_version = self.server_info.version # Handle cases where the ZenML server is not available except ConnectionError as e: @@ -500,7 +501,7 @@ def _initialize(self) -> None: except Exception as e: zenml_pro_extra = "" - if is_zenml_pro_server_url(self.url): + if ".zenml.io" in self.url: zenml_pro_extra = ( "\nHINT: " + get_troubleshooting_instructions(self.url) ) @@ -522,6 +523,17 @@ def _initialize(self) -> None: ENV_ZENML_DISABLE_CLIENT_SERVER_MISMATCH_WARNING, ) + @property + def server_info(self) -> ServerModel: + """Get cached information about the server. + + Returns: + Cached information about the server. + """ + if self._server_info is None: + return self.get_store_info() + return self._server_info + def get_store_info(self) -> ServerModel: """Get information about the server. @@ -529,7 +541,8 @@ def get_store_info(self) -> ServerModel: Information about the server. """ body = self.get(INFO) - return ServerModel.model_validate(body) + self._server_info = ServerModel.model_validate(body) + return self._server_info def get_deployment_id(self) -> UUID: """Get the ID of the deployment. @@ -537,7 +550,7 @@ def get_deployment_id(self) -> UUID: Returns: The ID of the deployment. """ - return self.get_store_info().id + return self.server_info.id # -------------------- Server Settings -------------------- @@ -4028,19 +4041,6 @@ def get_or_generate_api_token(self) -> str: token = credentials.api_token if credentials else None if credentials and token and not token.expired: self._api_token = token - - # Populate the server info in the credentials store if it is - # not already present - if not credentials.server_id: - try: - server_info = self.get_store_info() - except Exception as e: - logger.warning(f"Failed to get server info: {e}.") - else: - credentials_store.update_server_info( - self.url, server_info - ) - return self._api_token.access_token # Token is expired or not found in the cache. Time to get a new one. @@ -4089,13 +4089,23 @@ def get_or_generate_api_token(self) -> str: "username": username, "password": password, } - elif is_zenml_pro_server_url(self.url): + elif self.server_info.is_pro_server(): # ZenML Pro tenants use a proprietary authorization grant # where the ZenML Pro API session token is exchanged for a # regular ZenML server access token. # Get the ZenML Pro API session token, if cached and valid - pro_token = credentials_store.get_pro_token(allow_expired=True) + + # We need to determine the right ZenML Pro API URL to use + pro_api_url = self.server_info.pro_api_url + if not pro_api_url and credentials and credentials.pro_api_url: + pro_api_url = credentials.pro_api_url + if not pro_api_url: + pro_api_url = ZENML_PRO_API_URL + + pro_token = credentials_store.get_pro_token( + pro_api_url, allow_expired=True + ) if not pro_token: raise CredentialsNotValid( "You need to be logged in to ZenML Pro in order to "