MLCOMPUTE-1497 | add methods to get total driver memory including overhead #146

Merged: 6 commits, Aug 6, 2024
4 changes: 4 additions & 0 deletions service_configuration_lib/spark_config.py
@@ -414,6 +414,10 @@ def _filter_user_spark_opts(user_spark_opts: Mapping[str, str]) -> MutableMappin
}


def get_total_driver_memory_mb(spark_conf: Dict[str, str]) -> int:
return int(utils.get_spark_driver_memory_mb(spark_conf) + utils.get_spark_driver_memory_overhead_mb(spark_conf))


class SparkConfBuilder:

def __init__(self):
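
As a usage sketch (the configuration values here are assumed for illustration, not taken from this PR), the new helper simply sums the driver memory and its overhead in MB:

from service_configuration_lib.spark_config import get_total_driver_memory_mb

# Hypothetical config: 4 GiB driver memory plus an explicit 512 MB overhead.
spark_conf = {
    'spark.driver.memory': '4g',
    'spark.driver.memoryOverhead': '512',
}

total_mb = get_total_driver_memory_mb(spark_conf)  # 4096 + 512 = 4608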
60 changes: 60 additions & 0 deletions service_configuration_lib/utils.py
@@ -11,10 +11,12 @@
from socket import SO_REUSEADDR
from socket import socket
from socket import SOL_SOCKET
from typing import Dict
from typing import Mapping
from typing import Tuple

import yaml
from typing_extensions import Literal

DEFAULT_SPARK_RUN_CONFIG = '/nail/srv/configs/spark.yaml'
POD_TEMPLATE_PATH = '/nail/tmp/spark-pt-{file_uuid}.yaml'
@@ -24,6 +26,11 @@
EPHEMERAL_PORT_START = 49152
EPHEMERAL_PORT_END = 65535

MEM_MULTIPLIER = {'k': 1024, 'm': 1024**2, 'g': 1024**3, 't': 1024**4}

SPARK_DRIVER_MEM_DEFAULT_MB = 2048
SPARK_DRIVER_MEM_OVERHEAD_FACTOR_DEFAULT = 0.1


log = logging.Logger(__name__)
log.setLevel(logging.INFO)
@@ -148,3 +155,56 @@ def get_runtime_env() -> str:
# we could also just crash or return None, but this seems a little easier to find
# should we somehow run into this at Yelp
return 'unknown'


def get_spark_memory_in_unit(mem: str, unit: Literal['k', 'm', 'g', 't']) -> float:
"""
Converts a Spark memory string to the desired unit.
mem uses the JVM memory-string format: a plain number (interpreted as bytes) or a number followed by 'k', 'm', 'g' or 't'.
unit can be 'k', 'm', 'g' or 't'.
Returns the memory as a float in the desired unit.
"""
try:
memory_bytes = float(mem)
except ValueError:
try:
memory_bytes = float(mem[:-1]) * MEM_MULTIPLIER[mem[-1]]
except (ValueError, IndexError):
print(f'Unable to parse memory value {mem}.')
raise
memory_unit = memory_bytes / MEM_MULTIPLIER[unit]
return round(memory_unit, 5)
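
# Illustrative usage (values assumed for this sketch, not part of the diff):
# a plain number is interpreted as bytes, while a suffixed value uses the
# 'k'/'m'/'g'/'t' multipliers from MEM_MULTIPLIER above.
#     get_spark_memory_in_unit('2g', 'm')       # -> 2048.0
#     get_spark_memory_in_unit('1048576', 'm')  # -> 1.0
#     get_spark_memory_in_unit('512k', 'g')     # -> 0.00049 (rounded to 5 places)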


def get_spark_driver_memory_mb(spark_conf: Dict[str, str]) -> float:
"""
Returns the Spark driver memory in MB.
"""
# spark_conf is expected to have "spark.driver.memory" since it is a mandatory default from srv-configs.
driver_mem = spark_conf['spark.driver.memory']
try:
return get_spark_memory_in_unit(str(driver_mem), 'm')
except (ValueError, IndexError):
return SPARK_DRIVER_MEM_DEFAULT_MB
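
# Illustrative usage (values assumed for this sketch): a parse failure falls
# back to SPARK_DRIVER_MEM_DEFAULT_MB rather than raising.
#     get_spark_driver_memory_mb({'spark.driver.memory': '4g'})    # -> 4096.0
#     get_spark_driver_memory_mb({'spark.driver.memory': 'oops'})  # -> 2048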


def get_spark_driver_memory_overhead_mb(spark_conf: Dict[str, str]) -> float:
"""
Returns the Spark driver memory overhead in MB.
"""
# Use spark.driver.memoryOverhead if it is set.
try:
driver_mem_overhead = spark_conf['spark.driver.memoryOverhead']
try:
# spark.driver.memoryOverhead default unit is MB
driver_mem_overhead_mb = float(driver_mem_overhead)
except ValueError:
driver_mem_overhead_mb = get_spark_memory_in_unit(str(driver_mem_overhead), 'm')
# Calculate spark.driver.memoryOverhead based on spark.driver.memory and spark.driver.memoryOverheadFactor.
except Exception:
driver_mem_mb = get_spark_driver_memory_mb(spark_conf)
driver_mem_overhead_factor = float(
spark_conf.get('spark.driver.memoryOverheadFactor', SPARK_DRIVER_MEM_OVERHEAD_FACTOR_DEFAULT),
)
driver_mem_overhead_mb = driver_mem_mb * driver_mem_overhead_factor
return round(driver_mem_overhead_mb, 5)
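
To make the overhead precedence concrete, a short sketch (configuration dicts assumed for illustration): an explicit spark.driver.memoryOverhead takes priority; otherwise the overhead is derived from spark.driver.memory and spark.driver.memoryOverheadFactor, which defaults to 0.1.

from service_configuration_lib import utils

# Explicit overhead: '1g' parses to 1024 MB.
utils.get_spark_driver_memory_overhead_mb({'spark.driver.memoryOverhead': '1g'})  # -> 1024.0

# Derived overhead: 10240 MB * 0.1 (default factor) = 1024 MB.
utils.get_spark_driver_memory_overhead_mb({'spark.driver.memory': '10240m'})  # -> 1024.0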
2 changes: 1 addition & 1 deletion setup.py
Expand Up @@ -17,7 +17,7 @@

setup(
name='service-configuration-lib',
version='2.18.19',
version='2.18.20',
provides=['service_configuration_lib'],
description='Start, stop, and inspect Yelp SOA services',
url='https://github.com/Yelp/service_configuration_lib',
78 changes: 78 additions & 0 deletions tests/utils_test.py
@@ -2,11 +2,13 @@
from socket import SO_REUSEADDR
from socket import socket as Socket
from socket import SOL_SOCKET
from typing import cast
from unittest import mock
from unittest.mock import mock_open
from unittest.mock import patch

import pytest
from typing_extensions import Literal

from service_configuration_lib import utils
from service_configuration_lib.utils import ephemeral_port_reserve_range
@@ -74,6 +76,82 @@ def test_generate_pod_template_path(hex_value):
assert utils.generate_pod_template_path() == f'/nail/tmp/spark-pt-{hex_value}.yaml'


@pytest.mark.parametrize(
'mem_str,unit_str,expected_mem',
(
('13425m', 'm', 13425), # Simple case
('138412032', 'm', 132), # Bytes to MB
('65536k', 'g', 0.0625), # KB to GB
('1t', 'g', 1024), # TB to GB
('1.5g', 'm', 1536), # GB to MB with decimal
('2048k', 'm', 2), # KB to MB
('0.5g', 'k', 524288), # GB to KB
('32768m', 't', 0.03125), # MB to TB
('1.5t', 'm', 1572864), # TB to MB with decimal
),
)
def test_get_spark_memory_in_unit(mem_str, unit_str, expected_mem):
assert expected_mem == utils.get_spark_memory_in_unit(mem_str, cast(Literal['k', 'm', 'g', 't'], unit_str))


@pytest.mark.parametrize(
'mem_str,unit_str',
[
('invalid', 'm'),
('1024mb', 'g'),
],
)
def test_get_spark_memory_in_unit_exceptions(mem_str, unit_str):
with pytest.raises((ValueError, IndexError)):
utils.get_spark_memory_in_unit(mem_str, cast(Literal['k', 'm', 'g', 't'], unit_str))


@pytest.mark.parametrize(
'spark_conf,expected_mem',
[
({'spark.driver.memory': '13425m'}, 13425), # Simple case
({'spark.driver.memory': '138412032'}, 132), # Bytes to MB
({'spark.driver.memory': '65536k'}, 64), # KB to MB
({'spark.driver.memory': '1g'}, 1024), # GB to MB
({'spark.driver.memory': 'invalid'}, utils.SPARK_DRIVER_MEM_DEFAULT_MB), # Invalid case
({'spark.driver.memory': '1.5g'}, 1536), # GB to MB with decimal
({'spark.driver.memory': '2048k'}, 2), # KB to MB
({'spark.driver.memory': '0.5t'}, 524288), # TB to MB
({'spark.driver.memory': '1024m'}, 1024), # MB to MB
({'spark.driver.memory': '1.5t'}, 1572864), # TB to MB with decimal
],
)
def test_get_spark_driver_memory_mb(spark_conf, expected_mem):
assert expected_mem == utils.get_spark_driver_memory_mb(spark_conf)


@pytest.mark.parametrize(
'spark_conf,expected_mem_overhead',
[
({'spark.driver.memoryOverhead': '1024'}, 1024), # Simple case
({'spark.driver.memoryOverhead': '1g'}, 1024), # GB to MB
({'spark.driver.memory': '10240m', 'spark.driver.memoryOverheadFactor': '0.2'}, 2048), # Custom OverheadFactor
({'spark.driver.memory': '10240m'}, 1024), # Using default overhead factor
(
{'spark.driver.memory': 'invalid'},
utils.SPARK_DRIVER_MEM_DEFAULT_MB * utils.SPARK_DRIVER_MEM_OVERHEAD_FACTOR_DEFAULT,
),
# Invalid case
({'spark.driver.memoryOverhead': '1.5g'}, 1536), # GB to MB with decimal
({'spark.driver.memory': '2048k', 'spark.driver.memoryOverheadFactor': '0.05'}, 0.1),
# KB to MB with custom factor
({'spark.driver.memory': '0.5t', 'spark.driver.memoryOverheadFactor': '0.15'}, 78643.2),
# TB to MB with custom factor
({'spark.driver.memory': '1024m', 'spark.driver.memoryOverheadFactor': '0.25'}, 256),
# MB to MB with custom factor
({'spark.driver.memory': '1.5t', 'spark.driver.memoryOverheadFactor': '0.05'}, 78643.2),
# TB to MB with custom factor
],
)
def test_get_spark_driver_memory_overhead_mb(spark_conf, expected_mem_overhead):
assert expected_mem_overhead == utils.get_spark_driver_memory_overhead_mb(spark_conf)


@pytest.fixture
def mock_runtimeenv():
with patch('builtins.open', mock_open(read_data=MOCK_ENV_NAME)) as m: