From 05648245cbb8632714324ec4e986aa0c049d8b9d Mon Sep 17 00:00:00 2001 From: slicklash Date: Thu, 17 Aug 2023 15:15:15 +0300 Subject: [PATCH] [config-feeder] databricks: extract cluster id from hostname --- .../config_feeder/client/bigdata/__init__.py | 2 +- .../client/bigdata/databricks.py | 46 ++++++++++++++----- .../config_feeder/fixtures/base.py | 19 ++++++++ .../config_feeder/fixtures/databricks.py | 8 +++- .../test_databricks_node_info.py | 35 ++++++++++++++ 5 files changed, 96 insertions(+), 14 deletions(-) diff --git a/granulate_utils/config_feeder/client/bigdata/__init__.py b/granulate_utils/config_feeder/client/bigdata/__init__.py index 04048774..cae8bf09 100644 --- a/granulate_utils/config_feeder/client/bigdata/__init__.py +++ b/granulate_utils/config_feeder/client/bigdata/__init__.py @@ -14,7 +14,7 @@ def get_node_info(logger: Optional[Union[logging.Logger, logging.LoggerAdapter]] = None) -> Optional[NodeInfo]: if emr_node_info := get_emr_node_info(): return emr_node_info - if databricks_node_info := get_databricks_node_info(): + if databricks_node_info := get_databricks_node_info(logger): return databricks_node_info if dataproc_node_info := get_dataproc_node_info(logger): return dataproc_node_info diff --git a/granulate_utils/config_feeder/client/bigdata/databricks.py b/granulate_utils/config_feeder/client/bigdata/databricks.py index e2800643..c0924d30 100644 --- a/granulate_utils/config_feeder/client/bigdata/databricks.py +++ b/granulate_utils/config_feeder/client/bigdata/databricks.py @@ -1,4 +1,7 @@ -from typing import Dict, List, Optional +import logging +import os +import re +from typing import Dict, List, Optional, Union from granulate_utils.config_feeder.core.models.cluster import BigDataPlatform, CloudProvider from granulate_utils.config_feeder.core.models.node import NodeInfo @@ -14,8 +17,12 @@ KEY_CLUSTER_ID = f"{CLUSTER_KEY_PREFIX}Id" KEY_DRIVER_INSTANCE_ID = f"{DRIVER_KEY_PREFIX}InstanceId" +REGEX_CLUSTER_ID = r"(\d{4}-\d{6}-\w{8})" -def get_databricks_node_info() -> Optional[NodeInfo]: + +def get_databricks_node_info( + logger: Optional[Union[logging.Logger, logging.LoggerAdapter]] = None +) -> Optional[NodeInfo]: """ Returns Databricks node info """ @@ -23,16 +30,18 @@ def get_databricks_node_info() -> Optional[NodeInfo]: instance_id = properties[KEY_INSTANCE_ID] driver_instance_id = properties.get(KEY_DRIVER_INSTANCE_ID, "") provider = _resolve_cloud_provider(properties.get(KEY_CLOUD_PROVIDER, "unknown")) - external_cluster_id = properties[KEY_CLUSTER_ID] - return NodeInfo( - external_id=instance_id, - external_cluster_id=external_cluster_id, - is_master=(instance_id == driver_instance_id), - provider=provider, - bigdata_platform=BigDataPlatform.DATABRICKS, - bigdata_platform_version=get_databricks_version(), - properties=_exclude_keys(properties, [KEY_CLOUD_PROVIDER, KEY_INSTANCE_ID, KEY_CLUSTER_ID]), - ) + if external_cluster_id := _resolve_cluster_id(properties): + return NodeInfo( + external_id=instance_id, + external_cluster_id=external_cluster_id, + is_master=(instance_id == driver_instance_id), + provider=provider, + bigdata_platform=BigDataPlatform.DATABRICKS, + bigdata_platform_version=get_databricks_version(), + properties=_exclude_keys(properties, [KEY_CLOUD_PROVIDER, KEY_INSTANCE_ID, KEY_CLUSTER_ID]), + ) + elif logger: + logger.error("cannot resolve cluster id") return None @@ -63,6 +72,19 @@ def _get_deploy_conf() -> Optional[Dict[str, str]]: return None +def _resolve_cluster_id(properties: Dict[str, str]) -> Optional[str]: + """ + If clusterId is 
not available in deploy.conf, try to extract it from the hostname
+
+    e.g. 0817-103940-91u12104-10-26-238-244 -> 0817-103940-91u12104
+    """
+    if KEY_CLUSTER_ID in properties:
+        return properties[KEY_CLUSTER_ID]
+    if match := re.search(REGEX_CLUSTER_ID, os.uname()[1]):
+        return match.group(1)
+    return None
+
+
 def _resolve_cloud_provider(provider: str) -> CloudProvider:
     if provider == "AWS":
         return CloudProvider.AWS
diff --git a/tests/granulate_utils/config_feeder/fixtures/base.py b/tests/granulate_utils/config_feeder/fixtures/base.py
index c947d8d5..cc6af0fc 100644
--- a/tests/granulate_utils/config_feeder/fixtures/base.py
+++ b/tests/granulate_utils/config_feeder/fixtures/base.py
@@ -16,6 +16,7 @@ def __init__(self) -> None:
         self._stdout: Dict[str, bytes | str] = {}
         self._requests: List[Tuple[str, str, Dict[str, Any]]] = []
         self._contexts: List[Tuple[ContextManager[Any], Callable[[Any], None] | None]] = []
+        self._hostname: str = ""
 
     @property
     def node_info(self) -> NodeInfo:
@@ -35,6 +36,10 @@ def mock_http_response(self: NodeMockBase, method: str, url: str, response: Dict
         self._requests.append((method, url, response))
         return self
 
+    def mock_hostname(self, hostname: str) -> NodeMockBase:
+        self._hostname = hostname
+        return self
+
     def add_context(
         self: NodeMockBase, ctx: ContextManager[Any], fn: Callable[[Any], None] | None = None
     ) -> NodeMockBase:
@@ -89,6 +94,20 @@ def __enter__(self: NodeMockBase) -> NodeMockBase:
             )
         self.add_context(Mocker(), self._mock_http_response)
 
+        if self._hostname:
+            self.add_context(
+                patch(
+                    "os.uname",
+                    return_value=(
+                        "Linux",
+                        self._hostname,
+                        "5.15.0-79-generic",
+                        "#86-Ubuntu SMP Mon Jul 10 16:07:21 UTC 2023",
+                        "x86_64",
+                    ),
+                )
+            )
+
         for ctx, fn in self._contexts:
             value = ctx.__enter__()
             if fn is not None:
diff --git a/tests/granulate_utils/config_feeder/fixtures/databricks.py b/tests/granulate_utils/config_feeder/fixtures/databricks.py
index 853539c2..e3f6fbe0 100644
--- a/tests/granulate_utils/config_feeder/fixtures/databricks.py
+++ b/tests/granulate_utils/config_feeder/fixtures/databricks.py
@@ -10,6 +10,7 @@ def __init__(
         *,
         provider: CloudProvider = CloudProvider.AWS,
         cluster_id: str = "",
+        hostname: str = "",
         instance_id: str = "",
         is_master: bool = False,
         version: str = "11.3",
@@ -19,10 +20,12 @@ def __init__(
         properties = {
             "databricks.instance.metadata.cloudProvider": provider.upper(),
             "databricks.instance.metadata.instanceId": instance_id,
-            "spark.databricks.clusterUsageTags.clusterId": cluster_id,
             "spark.databricks.clusterUsageTags.clusterSomeSecretPassword": "password123",
         }
 
+        if cluster_id:
+            properties["spark.databricks.clusterUsageTags.clusterId"] = cluster_id
+
         if is_master:
             properties["spark.databricks.clusterUsageTags.driverInstanceId"] = driver_instance_id
 
@@ -35,3 +38,6 @@ def __init__(
         )
 
         self.mock_file("/databricks/DBR_VERSION", version)
+
+        if hostname:
+            self.mock_hostname(hostname)
diff --git a/tests/granulate_utils/config_feeder/test_databricks_node_info.py b/tests/granulate_utils/config_feeder/test_databricks_node_info.py
index 1626a927..8c1e0cb0 100644
--- a/tests/granulate_utils/config_feeder/test_databricks_node_info.py
+++ b/tests/granulate_utils/config_feeder/test_databricks_node_info.py
@@ -1,4 +1,7 @@
+from unittest.mock import Mock
+
 import pytest
+from requests_mock.exceptions import NoMockAddress
 
 from granulate_utils.config_feeder.client import get_node_info
 from granulate_utils.config_feeder.core.models.cluster import BigDataPlatform, CloudProvider
@@ -28,3 +31,35 @@ async def test_should_collect_node_info() 
-> None:
             "spark.databricks.clusterUsageTags.clusterSomeSecretPassword": "*****",
         },
     )
+
+
+@pytest.mark.asyncio
+async def test_should_extract_cluster_id_from_hostname() -> None:
+    instance_id = "i-000e86ee86c521650"
+    hostname = "0817-103940-91u12104-10-26-238-244"
+    with DatabricksNodeMock(
+        provider=CloudProvider.AWS,
+        hostname=hostname,
+        instance_id=instance_id,
+        is_master=False,
+    ):
+        assert get_node_info() == NodeInfo(
+            provider=CloudProvider.AWS,
+            bigdata_platform=BigDataPlatform.DATABRICKS,
+            bigdata_platform_version="11.3",
+            external_id=instance_id,
+            external_cluster_id="0817-103940-91u12104",
+            is_master=False,
+            properties={
+                "spark.databricks.clusterUsageTags.clusterSomeSecretPassword": "*****",
+            },
+        )
+
+
+@pytest.mark.asyncio
+async def test_should_log_cannot_resolve_cluster_id() -> None:
+    logger = Mock()
+    with DatabricksNodeMock(hostname="foo"):
+        with pytest.raises(NoMockAddress):
+            get_node_info(logger)  # raises NoMockAddress: detection falls through to an unmocked metadata endpoint
+        logger.error.assert_called_with("cannot resolve cluster id")
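
Reviewer note: a minimal standalone sketch of the hostname fallback this patch adds.
The helper name extract_cluster_id is hypothetical (the patch implements this inside
_resolve_cluster_id); the regex and the sample hostnames are taken from the diff.

    import re
    from typing import Optional

    # Databricks node hostnames embed the cluster id as a "dddd-dddddd-wwwwwwww" prefix,
    # e.g. "0817-103940-91u12104-10-26-238-244" -> "0817-103940-91u12104".
    REGEX_CLUSTER_ID = r"(\d{4}-\d{6}-\w{8})"

    def extract_cluster_id(hostname: str) -> Optional[str]:
        match = re.search(REGEX_CLUSTER_ID, hostname)
        return match.group(1) if match else None

    assert extract_cluster_id("0817-103940-91u12104-10-26-238-244") == "0817-103940-91u12104"
    assert extract_cluster_id("foo") is None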
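
Reviewer note: on the test side, mock_hostname works by patching os.uname, since
_resolve_cluster_id reads the hostname via os.uname()[1]. A self-contained sketch of
that mechanism, assuming the same fake 5-tuple the base.py fixture returns (only the
second field, nodename, is ever read by the code under test):

    import os
    from unittest.mock import patch

    fake_uname = (
        "Linux",
        "0817-103940-91u12104-10-26-238-244",  # nodename: the only field read
        "5.15.0-79-generic",
        "#86-Ubuntu SMP Mon Jul 10 16:07:21 UTC 2023",
        "x86_64",
    )

    with patch("os.uname", return_value=fake_uname):
        assert os.uname()[1] == "0817-103940-91u12104-10-26-238-244"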