[config-feeder] databricks: extract cluster id from hostname
slicklash committed Aug 17, 2023
1 parent dadafab commit 0564824
Showing 5 changed files with 96 additions and 14 deletions.
granulate_utils/config_feeder/client/bigdata/__init__.py (2 changes: 1 addition & 1 deletion)
@@ -14,7 +14,7 @@
 def get_node_info(logger: Optional[Union[logging.Logger, logging.LoggerAdapter]] = None) -> Optional[NodeInfo]:
     if emr_node_info := get_emr_node_info():
         return emr_node_info
-    if databricks_node_info := get_databricks_node_info():
+    if databricks_node_info := get_databricks_node_info(logger):
         return databricks_node_info
     if dataproc_node_info := get_dataproc_node_info(logger):
         return dataproc_node_info
granulate_utils/config_feeder/client/bigdata/databricks.py (46 changes: 34 additions & 12 deletions)
@@ -1,4 +1,7 @@
-from typing import Dict, List, Optional
+import logging
+import os
+import re
+from typing import Dict, List, Optional, Union
 
 from granulate_utils.config_feeder.core.models.cluster import BigDataPlatform, CloudProvider
 from granulate_utils.config_feeder.core.models.node import NodeInfo
@@ -14,25 +17,31 @@
 KEY_CLUSTER_ID = f"{CLUSTER_KEY_PREFIX}Id"
 KEY_DRIVER_INSTANCE_ID = f"{DRIVER_KEY_PREFIX}InstanceId"
 
+REGEX_CLUSTER_ID = r"(\d{4}-\d{6}-\w{8})"
 
-def get_databricks_node_info() -> Optional[NodeInfo]:
+
+def get_databricks_node_info(
+    logger: Optional[Union[logging.Logger, logging.LoggerAdapter]] = None
+) -> Optional[NodeInfo]:
     """
     Returns Databricks node info
     """
     if properties := _get_deploy_conf():
         instance_id = properties[KEY_INSTANCE_ID]
         driver_instance_id = properties.get(KEY_DRIVER_INSTANCE_ID, "")
         provider = _resolve_cloud_provider(properties.get(KEY_CLOUD_PROVIDER, "unknown"))
-        external_cluster_id = properties[KEY_CLUSTER_ID]
-        return NodeInfo(
-            external_id=instance_id,
-            external_cluster_id=external_cluster_id,
-            is_master=(instance_id == driver_instance_id),
-            provider=provider,
-            bigdata_platform=BigDataPlatform.DATABRICKS,
-            bigdata_platform_version=get_databricks_version(),
-            properties=_exclude_keys(properties, [KEY_CLOUD_PROVIDER, KEY_INSTANCE_ID, KEY_CLUSTER_ID]),
-        )
+        if external_cluster_id := _resolve_cluster_id(properties):
+            return NodeInfo(
+                external_id=instance_id,
+                external_cluster_id=external_cluster_id,
+                is_master=(instance_id == driver_instance_id),
+                provider=provider,
+                bigdata_platform=BigDataPlatform.DATABRICKS,
+                bigdata_platform_version=get_databricks_version(),
+                properties=_exclude_keys(properties, [KEY_CLOUD_PROVIDER, KEY_INSTANCE_ID, KEY_CLUSTER_ID]),
+            )
+        elif logger:
+            logger.error("cannot resolve cluster id")
     return None
 
 
@@ -63,6 +72,19 @@ def _get_deploy_conf() -> Optional[Dict[str, str]]:
     return None
 
 
+def _resolve_cluster_id(properties: Dict[str, str]) -> Optional[str]:
+    """
+    If clusterId is not available in deploy.conf, try to extract it from hostname
+
+    e.g. 0817-103940-91u12104-10-26-238-244 -> 0817-103940-91u12104
+    """
+    if KEY_CLUSTER_ID in properties:
+        return properties[KEY_CLUSTER_ID]
+    if match := re.search(REGEX_CLUSTER_ID, os.uname()[1]):
+        return match.group(1)
+    return None
+
+
 def _resolve_cloud_provider(provider: str) -> CloudProvider:
     if provider == "AWS":
         return CloudProvider.AWS
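
For reference, a minimal standalone sketch of the fallback this hunk introduces (the regex and the example hostname are taken from the diff above; the helper name is illustrative):

import re
from typing import Optional

REGEX_CLUSTER_ID = r"(\d{4}-\d{6}-\w{8})"

def cluster_id_from_hostname(hostname: str) -> Optional[str]:
    # Databricks node hostnames embed the cluster id followed by what looks
    # like the node IP with dashes, e.g.
    # "0817-103940-91u12104-10-26-238-244" -> "0817-103940-91u12104".
    match = re.search(REGEX_CLUSTER_ID, hostname)
    return match.group(1) if match else None

assert cluster_id_from_hostname("0817-103940-91u12104-10-26-238-244") == "0817-103940-91u12104"
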
tests/granulate_utils/config_feeder/fixtures/base.py (19 changes: 19 additions & 0 deletions)
@@ -16,6 +16,7 @@ def __init__(self) -> None:
         self._stdout: Dict[str, bytes | str] = {}
         self._requests: List[Tuple[str, str, Dict[str, Any]]] = []
         self._contexts: List[Tuple[ContextManager[Any], Callable[[Any], None] | None]] = []
+        self._hostname: str = ""
 
     @property
     def node_info(self) -> NodeInfo:
@@ -35,6 +36,10 @@ def mock_http_response(self: NodeMockBase, method: str, url: str, response: Dict
         self._requests.append((method, url, response))
         return self
 
+    def mock_hostname(self, hostname: str) -> NodeMockBase:
+        self._hostname = hostname
+        return self
+
     def add_context(
         self: NodeMockBase, ctx: ContextManager[Any], fn: Callable[[Any], None] | None = None
     ) -> NodeMockBase:
@@ -89,6 +94,20 @@ def __enter__(self: NodeMockBase) -> NodeMockBase:
         )
         self.add_context(Mocker(), self._mock_http_response)
 
+        if self._hostname:
+            self.add_context(
+                patch(
+                    "os.uname",
+                    return_value=(
+                        "Linux",
+                        self._hostname,
+                        "5.15.0-79-generic",
+                        "#86-Ubuntu SMP Mon Jul 10 16:07:21 UTC 2023",
+                        "x86_64",
+                    ),
+                )
+            )
+
         for ctx, fn in self._contexts:
            value = ctx.__enter__()
            if fn is not None:
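
The os.uname patch above works because os.uname() returns a five-field result (sysname, nodename, release, version, machine) and the production code only reads index 1, the hostname, so a plain tuple is a sufficient stand-in. A standalone sketch of the same technique (values illustrative):

import os
from unittest.mock import patch

fake_uname = ("Linux", "0817-103940-91u12104-10-26-238-244", "5.15.0-79-generic", "#86-Ubuntu SMP", "x86_64")
with patch("os.uname", return_value=fake_uname):
    assert os.uname()[1] == "0817-103940-91u12104-10-26-238-244"
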
tests/granulate_utils/config_feeder/fixtures/databricks.py (8 changes: 7 additions & 1 deletion)
@@ -10,6 +10,7 @@ def __init__(
         *,
         provider: CloudProvider = CloudProvider.AWS,
         cluster_id: str = "",
+        hostname: str = "",
         instance_id: str = "",
         is_master: bool = False,
         version: str = "11.3",
@@ -19,10 +20,12 @@
         properties = {
             "databricks.instance.metadata.cloudProvider": provider.upper(),
             "databricks.instance.metadata.instanceId": instance_id,
-            "spark.databricks.clusterUsageTags.clusterId": cluster_id,
             "spark.databricks.clusterUsageTags.clusterSomeSecretPassword": "password123",
         }
 
+        if cluster_id:
+            properties["spark.databricks.clusterUsageTags.clusterId"] = cluster_id
+
         if is_master:
             properties["spark.databricks.clusterUsageTags.driverInstanceId"] = driver_instance_id
 
@@ -35,3 +38,6 @@
         )
 
         self.mock_file("/databricks/DBR_VERSION", version)
+
+        if hostname:
+            self.mock_hostname(hostname)
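
With these fixture changes, clusterId is only written into deploy.conf when explicitly provided, so a test can force the hostname fallback by passing hostname alone — a usage sketch mirroring the new test below:

from granulate_utils.config_feeder.client import get_node_info

with DatabricksNodeMock(hostname="0817-103940-91u12104-10-26-238-244", instance_id="i-000e86ee86c521650"):
    node_info = get_node_info()  # external_cluster_id is resolved from the mocked hostname
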
tests/granulate_utils/config_feeder/test_databricks_node_info.py (35 changes: 35 additions & 0 deletions)
@@ -1,4 +1,7 @@
+from unittest.mock import Mock
+
 import pytest
+from requests_mock.exceptions import NoMockAddress
 
 from granulate_utils.config_feeder.client import get_node_info
 from granulate_utils.config_feeder.core.models.cluster import BigDataPlatform, CloudProvider
@@ -28,3 +31,35 @@ async def test_should_collect_node_info() -> None:
             "spark.databricks.clusterUsageTags.clusterSomeSecretPassword": "*****",
         },
     )
+
+
+@pytest.mark.asyncio
+async def test_should_extract_cluster_id_from_hostname() -> None:
+    instance_id = "i-000e86ee86c521650"
+    hostname = "0817-103940-91u12104-10-26-238-244"
+    with DatabricksNodeMock(
+        provider=CloudProvider.AWS,
+        hostname=hostname,
+        instance_id=instance_id,
+        is_master=False,
+    ):
+        assert get_node_info() == NodeInfo(
+            provider=CloudProvider.AWS,
+            bigdata_platform=BigDataPlatform.DATABRICKS,
+            bigdata_platform_version="11.3",
+            external_id=instance_id,
+            external_cluster_id="0817-103940-91u12104",
+            is_master=False,
+            properties={
+                "spark.databricks.clusterUsageTags.clusterSomeSecretPassword": "*****",
+            },
+        )
+
+
+@pytest.mark.asyncio
+async def test_should_log_cannot_resolve_cluster_id() -> None:
+    logger = Mock()
+    with DatabricksNodeMock(hostname="foo"):
+        with pytest.raises(NoMockAddress):
+            assert get_node_info(logger) is None
+        logger.error.assert_called_with("cannot resolve cluster id")
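
Why the last test expects NoMockAddress rather than a plain None: with no clusterId in deploy.conf and a hostname ("foo") that doesn't match REGEX_CLUSTER_ID, get_databricks_node_info() returns None after logging the error; get_node_info() then presumably falls through to the next platform probe, whose unmocked HTTP request trips requests_mock's NoMockAddress — hence the pytest.raises wrapper around the assertion.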
