diff --git a/service_configuration_lib/spark_config.py b/service_configuration_lib/spark_config.py index db39de0..8e959f5 100644 --- a/service_configuration_lib/spark_config.py +++ b/service_configuration_lib/spark_config.py @@ -414,6 +414,10 @@ def _filter_user_spark_opts(user_spark_opts: Mapping[str, str]) -> MutableMappin } +def get_total_driver_memory_mb(spark_conf: Dict[str, str]) -> int: + return int(utils.get_spark_driver_memory_mb(spark_conf) + utils.get_spark_driver_memory_overhead_mb(spark_conf)) + + class SparkConfBuilder: def __init__(self): diff --git a/service_configuration_lib/utils.py b/service_configuration_lib/utils.py index fcfb6eb..95d5750 100644 --- a/service_configuration_lib/utils.py +++ b/service_configuration_lib/utils.py @@ -11,10 +11,12 @@ from socket import SO_REUSEADDR from socket import socket from socket import SOL_SOCKET +from typing import Dict from typing import Mapping from typing import Tuple import yaml +from typing_extensions import Literal DEFAULT_SPARK_RUN_CONFIG = '/nail/srv/configs/spark.yaml' POD_TEMPLATE_PATH = '/nail/tmp/spark-pt-{file_uuid}.yaml' @@ -24,6 +26,11 @@ EPHEMERAL_PORT_START = 49152 EPHEMERAL_PORT_END = 65535 +MEM_MULTIPLIER = {'k': 1024, 'm': 1024**2, 'g': 1024**3, 't': 1024**4} + +SPARK_DRIVER_MEM_DEFAULT_MB = 2048 +SPARK_DRIVER_MEM_OVERHEAD_FACTOR_DEFAULT = 0.1 + log = logging.Logger(__name__) log.setLevel(logging.INFO) @@ -148,3 +155,56 @@ def get_runtime_env() -> str: # we could also just crash or return None, but this seems a little easier to find # should we somehow run into this at Yelp return 'unknown' + + +def get_spark_memory_in_unit(mem: str, unit: Literal['k', 'm', 'g', 't']) -> float: + """ + Converts Spark memory to the desired unit. + mem is the same format as JVM memory strings: just number or number followed by 'k', 'm', 'g' or 't'. + unit can be 'k', 'm', 'g' or 't'. + Returns memory as a float converted to the desired unit. + """ + try: + memory_bytes = float(mem) + except ValueError: + try: + memory_bytes = float(mem[:-1]) * MEM_MULTIPLIER[mem[-1]] + except (ValueError, IndexError): + print(f'Unable to parse memory value {mem}.') + raise + memory_unit = memory_bytes / MEM_MULTIPLIER[unit] + return round(memory_unit, 5) + + +def get_spark_driver_memory_mb(spark_conf: Dict[str, str]) -> float: + """ + Returns the Spark driver memory in MB. + """ + # spark_conf is expected to have "spark.driver.memory" since it is a mandatory default from srv-configs. + driver_mem = spark_conf['spark.driver.memory'] + try: + return get_spark_memory_in_unit(str(driver_mem), 'm') + except (ValueError, IndexError): + return SPARK_DRIVER_MEM_DEFAULT_MB + + +def get_spark_driver_memory_overhead_mb(spark_conf: Dict[str, str]) -> float: + """ + Returns the Spark driver memory overhead in bytes. + """ + # Use spark.driver.memoryOverhead if it is set. + try: + driver_mem_overhead = spark_conf['spark.driver.memoryOverhead'] + try: + # spark.driver.memoryOverhead default unit is MB + driver_mem_overhead_mb = float(driver_mem_overhead) + except ValueError: + driver_mem_overhead_mb = get_spark_memory_in_unit(str(driver_mem_overhead), 'm') + # Calculate spark.driver.memoryOverhead based on spark.driver.memory and spark.driver.memoryOverheadFactor. + except Exception: + driver_mem_mb = get_spark_driver_memory_mb(spark_conf) + driver_mem_overhead_factor = float( + spark_conf.get('spark.driver.memoryOverheadFactor', SPARK_DRIVER_MEM_OVERHEAD_FACTOR_DEFAULT), + ) + driver_mem_overhead_mb = driver_mem_mb * driver_mem_overhead_factor + return round(driver_mem_overhead_mb, 5) diff --git a/setup.py b/setup.py index a3cb438..7f0d2d5 100644 --- a/setup.py +++ b/setup.py @@ -17,7 +17,7 @@ setup( name='service-configuration-lib', - version='2.18.19', + version='2.18.20', provides=['service_configuration_lib'], description='Start, stop, and inspect Yelp SOA services', url='https://github.com/Yelp/service_configuration_lib', diff --git a/tests/utils_test.py b/tests/utils_test.py index 6db80a3..9f49145 100644 --- a/tests/utils_test.py +++ b/tests/utils_test.py @@ -2,11 +2,13 @@ from socket import SO_REUSEADDR from socket import socket as Socket from socket import SOL_SOCKET +from typing import cast from unittest import mock from unittest.mock import mock_open from unittest.mock import patch import pytest +from typing_extensions import Literal from service_configuration_lib import utils from service_configuration_lib.utils import ephemeral_port_reserve_range @@ -74,6 +76,82 @@ def test_generate_pod_template_path(hex_value): assert utils.generate_pod_template_path() == f'/nail/tmp/spark-pt-{hex_value}.yaml' +@pytest.mark.parametrize( + 'mem_str,unit_str,expected_mem', + ( + ('13425m', 'm', 13425), # Simple case + ('138412032', 'm', 132), # Bytes to MB + ('65536k', 'g', 0.0625), # KB to GB + ('1t', 'g', 1024), # TB to GB + ('1.5g', 'm', 1536), # GB to MB with decimal + ('2048k', 'm', 2), # KB to MB + ('0.5g', 'k', 524288), # GB to KB + ('32768m', 't', 0.03125), # MB to TB + ('1.5t', 'm', 1572864), # TB to MB with decimal + ), +) +def test_get_spark_memory_in_unit(mem_str, unit_str, expected_mem): + assert expected_mem == utils.get_spark_memory_in_unit(mem_str, cast(Literal['k', 'm', 'g', 't'], unit_str)) + + +@pytest.mark.parametrize( + 'mem_str,unit_str', + [ + ('invalid', 'm'), + ('1024mb', 'g'), + ], +) +def test_get_spark_memory_in_unit_exceptions(mem_str, unit_str): + with pytest.raises((ValueError, IndexError)): + utils.get_spark_memory_in_unit(mem_str, cast(Literal['k', 'm', 'g', 't'], unit_str)) + + +@pytest.mark.parametrize( + 'spark_conf,expected_mem', + [ + ({'spark.driver.memory': '13425m'}, 13425), # Simple case + ({'spark.driver.memory': '138412032'}, 132), # Bytes to MB + ({'spark.driver.memory': '65536k'}, 64), # KB to MB + ({'spark.driver.memory': '1g'}, 1024), # GB to MB + ({'spark.driver.memory': 'invalid'}, utils.SPARK_DRIVER_MEM_DEFAULT_MB), # Invalid case + ({'spark.driver.memory': '1.5g'}, 1536), # GB to MB with decimal + ({'spark.driver.memory': '2048k'}, 2), # KB to MB + ({'spark.driver.memory': '0.5t'}, 524288), # TB to MB + ({'spark.driver.memory': '1024m'}, 1024), # MB to MB + ({'spark.driver.memory': '1.5t'}, 1572864), # TB to MB with decimal + ], +) +def test_get_spark_driver_memory_mb(spark_conf, expected_mem): + assert expected_mem == utils.get_spark_driver_memory_mb(spark_conf) + + +@pytest.mark.parametrize( + 'spark_conf,expected_mem_overhead', + [ + ({'spark.driver.memoryOverhead': '1024'}, 1024), # Simple case + ({'spark.driver.memoryOverhead': '1g'}, 1024), # GB to MB + ({'spark.driver.memory': '10240m', 'spark.driver.memoryOverheadFactor': '0.2'}, 2048), # Custom OverheadFactor + ({'spark.driver.memory': '10240m'}, 1024), # Using default overhead factor + ( + {'spark.driver.memory': 'invalid'}, + utils.SPARK_DRIVER_MEM_DEFAULT_MB * utils.SPARK_DRIVER_MEM_OVERHEAD_FACTOR_DEFAULT, + ), + # Invalid case + ({'spark.driver.memoryOverhead': '1.5g'}, 1536), # GB to MB with decimal + ({'spark.driver.memory': '2048k', 'spark.driver.memoryOverheadFactor': '0.05'}, 0.1), + # KB to MB with custom factor + ({'spark.driver.memory': '0.5t', 'spark.driver.memoryOverheadFactor': '0.15'}, 78643.2), + # TB to MB with custom factor + ({'spark.driver.memory': '1024m', 'spark.driver.memoryOverheadFactor': '0.25'}, 256), + # MB to MB with custom factor + ({'spark.driver.memory': '1.5t', 'spark.driver.memoryOverheadFactor': '0.05'}, 78643.2), + # TB to MB with custom factor + ], +) +def test_get_spark_driver_memory_overhead_mb(spark_conf, expected_mem_overhead): + assert expected_mem_overhead == utils.get_spark_driver_memory_overhead_mb(spark_conf) + + @pytest.fixture def mock_runtimeenv(): with patch('builtins.open', mock_open(read_data=MOCK_ENV_NAME)) as m: