From 2d20c2ed5911ae4d5c73820b7697ba4558fd0158 Mon Sep 17 00:00:00 2001 From: Stefan Nica Date: Mon, 16 Dec 2024 19:32:49 +0100 Subject: [PATCH] Allow connecting to external servers enrolled as ZenML Pro tenants --- src/zenml/cli/login.py | 18 ++++++------ src/zenml/models/v2/misc/server_models.py | 8 ++++++ src/zenml/zen_server/auth.py | 2 +- src/zenml/zen_stores/rest_zen_store.py | 34 +++++++++++------------ 4 files changed, 36 insertions(+), 26 deletions(-) diff --git a/src/zenml/cli/login.py b/src/zenml/cli/login.py index 486be5baea..4cddc93bd3 100644 --- a/src/zenml/cli/login.py +++ b/src/zenml/cli/login.py @@ -145,6 +145,7 @@ def connect_to_server( api_key: Optional[str] = None, verify_ssl: Union[str, bool] = True, refresh: bool = False, + pro_server: bool = False, ) -> None: """Connect the client to a ZenML server or a SQL database. @@ -154,6 +155,7 @@ def connect_to_server( verify_ssl: Whether to verify the server's TLS certificate. If a string is passed, it is interpreted as the path to a CA bundle file. refresh: Whether to force a new login flow with the ZenML server. + pro_server: Whether the server is a ZenML Pro server. """ from zenml.login.credentials_store import get_credentials_store from zenml.zen_stores.base_zen_store import BaseZenStore @@ -170,7 +172,12 @@ def connect_to_server( f"Authenticating to ZenML server '{url}' using an API key..." ) credentials_store.set_api_key(url, api_key) - elif not is_zenml_pro_server_url(url): + elif pro_server: + # We don't have to do anything here assuming the user has already + # logged in to the ZenML Pro server using the ZenML Pro web login + # flow. + cli_utils.declare(f"Authenticating to ZenML server '{url}'...") + else: if refresh or not credentials_store.has_valid_authentication(url): cli_utils.declare( f"Authenticating to ZenML server '{url}' using the web " @@ -179,11 +186,6 @@ def connect_to_server( web_login(url=url, verify_ssl=verify_ssl) else: cli_utils.declare(f"Connecting to ZenML server '{url}'...") - else: - # We don't have to do anything here assuming the user has already - # logged in to the ZenML Pro server using the ZenML Pro web login - # flow. - cli_utils.declare(f"Authenticating to ZenML server '{url}'...") rest_store_config = RestZenStoreConfiguration( url=url, @@ -277,7 +279,7 @@ def connect_to_pro_server( # server to connect to. if api_key: if server_url: - connect_to_server(server_url, api_key=api_key) + connect_to_server(server_url, api_key=api_key, pro_server=True) return else: raise ValueError( @@ -405,7 +407,7 @@ def connect_to_pro_server( f"Connecting to ZenML Pro server: {server.name} [{str(server.id)}] " ) - connect_to_server(server.url, api_key=api_key) + connect_to_server(server.url, api_key=api_key, pro_server=True) # Update the stored server info with more accurate data taken from the # ZenML Pro tenant object. diff --git a/src/zenml/models/v2/misc/server_models.py b/src/zenml/models/v2/misc/server_models.py index be548189c3..d7e619020a 100644 --- a/src/zenml/models/v2/misc/server_models.py +++ b/src/zenml/models/v2/misc/server_models.py @@ -119,6 +119,14 @@ def is_local(self) -> bool: # server ID is the same as the local client (user) ID. return self.id == GlobalConfiguration().user_id + def is_pro_server(self) -> bool: + """Return whether the server is a ZenML Pro server. + + Returns: + True if the server is a ZenML Pro server, False otherwise. + """ + return self.deployment_type == ServerDeploymentType.CLOUD + class ServerLoadInfo(BaseModel): """Domain model for ZenML server load information.""" diff --git a/src/zenml/zen_server/auth.py b/src/zenml/zen_server/auth.py index 1b2a1b976b..f47a56f56c 100644 --- a/src/zenml/zen_server/auth.py +++ b/src/zenml/zen_server/auth.py @@ -1003,12 +1003,12 @@ async def __call__(self, request: Request) -> Optional[str]: def oauth2_authentication( + request: Request, token: str = Depends( CookieOAuth2TokenBearer( tokenUrl=server_config().root_url_path + API + VERSION_1 + LOGIN, ) ), - request: Request = Depends(), ) -> AuthContext: """Authenticates any request to the ZenML server with OAuth2 JWT tokens. diff --git a/src/zenml/zen_stores/rest_zen_store.py b/src/zenml/zen_stores/rest_zen_store.py index e3e29759b2..5dcdb4881a 100644 --- a/src/zenml/zen_stores/rest_zen_store.py +++ b/src/zenml/zen_stores/rest_zen_store.py @@ -450,6 +450,7 @@ class RestZenStore(BaseZenStore): CONFIG_TYPE: ClassVar[Type[StoreConfiguration]] = RestZenStoreConfiguration _api_token: Optional[APIToken] = None _session: Optional[requests.Session] = None + _server_info: Optional[ServerModel] = None # ==================================== # ZenML Store interface implementation @@ -469,7 +470,7 @@ def _initialize(self) -> None: """ try: client_version = zenml.__version__ - server_version = self.get_store_info().version + server_version = self.server_info.version # Handle cases where the ZenML server is not available except ConnectionError as e: @@ -522,6 +523,17 @@ def _initialize(self) -> None: ENV_ZENML_DISABLE_CLIENT_SERVER_MISMATCH_WARNING, ) + @property + def server_info(self) -> ServerModel: + """Get cached information about the server. + + Returns: + Cached information about the server. + """ + if self._server_info is None: + return self.get_store_info() + return self._server_info + def get_store_info(self) -> ServerModel: """Get information about the server. @@ -529,7 +541,8 @@ def get_store_info(self) -> ServerModel: Information about the server. """ body = self.get(INFO) - return ServerModel.model_validate(body) + self._server_info = ServerModel.model_validate(body) + return self._server_info def get_deployment_id(self) -> UUID: """Get the ID of the deployment. @@ -537,7 +550,7 @@ def get_deployment_id(self) -> UUID: Returns: The ID of the deployment. """ - return self.get_store_info().id + return self.server_info.id # -------------------- Server Settings -------------------- @@ -4028,19 +4041,6 @@ def get_or_generate_api_token(self) -> str: token = credentials.api_token if credentials else None if credentials and token and not token.expired: self._api_token = token - - # Populate the server info in the credentials store if it is - # not already present - if not credentials.server_id: - try: - server_info = self.get_store_info() - except Exception as e: - logger.warning(f"Failed to get server info: {e}.") - else: - credentials_store.update_server_info( - self.url, server_info - ) - return self._api_token.access_token # Token is expired or not found in the cache. Time to get a new one. @@ -4084,7 +4084,7 @@ def get_or_generate_api_token(self) -> str: "username": username, "password": password, } - elif is_zenml_pro_server_url(self.url): + elif self.server_info.is_pro_server(): # ZenML Pro tenants use a proprietary authorization grant # where the ZenML Pro API session token is exchanged for a # regular ZenML server access token.