MLCOMPUTE-1160 | consolidate tron and paasta logic to service config lib (#139)

* MLCOMPUTE-1160 | consolidate tron and paasta logic here

* MLCOMPUTE-1160 | bump up version and fix tests

* MLCOMPUTE-1160 | move Spark executor pod template to srv configs

* MLCOMPUTE-1160 | adding tests

* MLCOMPUTE-1160 | fix pod template formatting

---------

Co-authored-by: Sameer Sharma <[email protected]>
CaptainSame and Sameer Sharma authored Feb 15, 2024
1 parent 768dda5 commit c8f0180
Showing 5 changed files with 176 additions and 74 deletions.
125 changes: 67 additions & 58 deletions service_configuration_lib/spark_config.py
@@ -1,6 +1,4 @@
import base64
import functools
import hashlib
import itertools
import json
import logging
@@ -265,6 +263,7 @@ def _get_k8s_spark_env(
paasta_service: str,
paasta_instance: str,
docker_img: str,
pod_template_path: str,
volumes: Optional[List[Mapping[str, str]]],
paasta_pool: str,
driver_ui_port: int,
@@ -275,9 +274,9 @@
# RFC 1123: https://kubernetes.io/docs/concepts/overview/working-with-objects/names/#dns-label-names
# technically only paasta instance can be longer than 63 chars. But we apply the normalization regardless.
# NOTE: this affects only k8s labels, not the pod names.
_paasta_cluster = _get_k8s_resource_name_limit_size_with_hash(paasta_cluster)
_paasta_service = _get_k8s_resource_name_limit_size_with_hash(paasta_service)
_paasta_instance = _get_k8s_resource_name_limit_size_with_hash(paasta_instance)
_paasta_cluster = utils.get_k8s_resource_name_limit_size_with_hash(paasta_cluster)
_paasta_service = utils.get_k8s_resource_name_limit_size_with_hash(paasta_service)
_paasta_instance = utils.get_k8s_resource_name_limit_size_with_hash(paasta_instance)
user = os.environ.get('USER', '_unspecified_')

spark_env = {
@@ -302,6 +301,7 @@
'spark.kubernetes.executor.label.yelp.com/pool': paasta_pool,
'spark.kubernetes.executor.label.paasta.yelp.com/pool': paasta_pool,
'spark.kubernetes.executor.label.yelp.com/owner': 'core_ml',
'spark.kubernetes.executor.podTemplateFile': pod_template_path,
**_get_k8s_docker_volumes_conf(volumes),
}
if service_account_name is not None:
@@ -347,24 +347,6 @@ def _get_local_spark_env(
}


def _get_k8s_resource_name_limit_size_with_hash(name: str, limit: int = 63, suffix: int = 4) -> str:
""" Returns `name` unchanged if it's length does not exceed the `limit`.
Otherwise, returns truncated `name` with it's hash of size `suffix`
appended.
base32 encoding is chosen as it satisfies the common requirement in
various k8s names to be alphanumeric.
NOTE: This function is the same as paasta/paasta_tools/kubernetes_tools.py
"""
if len(name) > limit:
digest = hashlib.md5(name.encode()).digest()
hash = base64.b32encode(digest).decode().replace('=', '').lower()
return f'{name[:(limit-suffix-1)]}-{hash[:suffix]}'
else:
return name


def stringify_spark_env(spark_env: Mapping[str, str]) -> str:
return ' '.join([f'--conf {k}={v}' for k, v in spark_env.items()])

@@ -913,8 +895,8 @@ def get_history_url(self, spark_conf: Mapping[str, str]) -> Optional[str]:
def _append_event_log_conf(
self,
spark_opts: Dict[str, str],
access_key: Optional[str],
secret_key: Optional[str],
access_key: Optional[str] = None,
secret_key: Optional[str] = None,
session_token: Optional[str] = None,
) -> Dict[str, str]:
enabled = spark_opts.setdefault('spark.eventLog.enabled', 'true').lower()
@@ -929,31 +911,44 @@

if len(self.spark_srv_conf.items()) == 0:
log.warning('spark_srv_conf is empty, disable event log')
spark_opts.update({'spark.eventLog.enabled': 'false'})
return spark_opts

try:
account_id = (
boto3.client(
'sts',
aws_access_key_id=access_key,
aws_secret_access_key=secret_key,
aws_session_token=session_token,
if access_key:
try:
account_id = (
boto3.client(
'sts',
aws_access_key_id=access_key,
aws_secret_access_key=secret_key,
aws_session_token=session_token,
)
.get_caller_identity()
.get('Account')
)
.get_caller_identity()
.get('Account')
)
except Exception as e:
log.warning('Failed to identify account ID, error: {}'.format(str(e)))
spark_opts['spark.eventLog.enabled'] = 'false'
return spark_opts
except Exception as e:
log.warning('Failed to identify account ID, error: {}'.format(str(e)))
spark_opts['spark.eventLog.enabled'] = 'false'
return spark_opts

for conf in self.spark_srv_conf.get('environments', {}).values():
if account_id == conf['account_id']:
spark_opts['spark.eventLog.enabled'] = 'true'
spark_opts['spark.eventLog.dir'] = conf['default_event_log_dir']
return spark_opts

for conf in self.spark_srv_conf.get('environments', {}).values():
if account_id == conf['account_id']:
spark_opts['spark.eventLog.enabled'] = 'true'
spark_opts['spark.eventLog.dir'] = conf['default_event_log_dir']
else:
environment_config = self.spark_srv_conf.get('environments', {}).get(
utils.get_runtime_env(),
)
if environment_config:
spark_opts.update({
'spark.eventLog.enabled': 'true',
'spark.eventLog.dir': environment_config['default_event_log_dir'],
})
return spark_opts

log.warning(f'Disable event log because No preset event log dir for account: {account_id}')
log.warning('Disable event log because No preset event log dir')
spark_opts['spark.eventLog.enabled'] = 'false'
return spark_opts

@@ -1014,7 +1009,7 @@ def get_spark_conf(
paasta_service: str,
paasta_instance: str,
docker_img: str,
aws_creds: Tuple[Optional[str], Optional[str], Optional[str]],
aws_creds: Optional[Tuple[Optional[str], Optional[str], Optional[str]]] = None,
extra_volumes: Optional[List[Mapping[str, str]]] = None,
use_eks: bool = False,
k8s_server_address: Optional[str] = None,
@@ -1080,10 +1075,10 @@
spark_conf = {**(spark_opts_from_env or {}), **_filter_user_spark_opts(user_spark_opts)}
random_postfix = utils.get_random_string(4)

if aws_creds[2] is not None:
if aws_creds is not None and aws_creds[2] is not None:
spark_conf['spark.hadoop.fs.s3a.aws.credentials.provider'] = AWS_ENV_CREDENTIALS_PROVIDER

# app_name from env is already appended port and time to make it unique
# app_name from env is already appended with port and time to make it unique
app_name = (spark_opts_from_env or {}).get('spark.app.name')
if not app_name:
app_name = f'{app_base_name}_{ui_port}_{int(time.time())}_{random_postfix}'
@@ -1097,10 +1092,10 @@
raw_app_id = app_name
else:
raw_app_id = f'{paasta_service}__{paasta_instance}__{random_postfix}'
app_id = re.sub(r'[\.,-]', '_', _get_k8s_resource_name_limit_size_with_hash(raw_app_id))
app_id = re.sub(r'[\.,-]', '_', utils.get_k8s_resource_name_limit_size_with_hash(raw_app_id))

# Starting Spark 3.4+, spark-app-name label has been added. Limiting to 63 characters
app_name = _get_k8s_resource_name_limit_size_with_hash(app_name)
app_name = utils.get_k8s_resource_name_limit_size_with_hash(app_name)

spark_conf.update({
'spark.app.name': app_name,
@@ -1113,15 +1108,24 @@
spark_conf, cluster_manager, paasta_pool, force_spark_resource_configs,
)

# Add pod template file
pod_template_path = utils.generate_pod_template_path()
try:
utils.create_pod_template(pod_template_path, app_base_name)
except Exception as e:
log.error(f'Failed to generate Spark executor pod template: {e}')
pod_template_path = ''

if cluster_manager == 'kubernetes':
spark_conf.update(_get_k8s_spark_env(
paasta_cluster,
paasta_service,
paasta_instance,
docker_img,
extra_volumes,
paasta_pool,
ui_port,
paasta_cluster=paasta_cluster,
paasta_service=paasta_service,
paasta_instance=paasta_instance,
docker_img=docker_img,
pod_template_path=pod_template_path,
volumes=extra_volumes,
paasta_pool=paasta_pool,
driver_ui_port=ui_port,
service_account_name=service_account_name,
include_self_managed_configs=not use_eks,
k8s_server_address=k8s_server_address,
@@ -1146,7 +1150,10 @@
spark_conf = self._append_spark_prometheus_conf(spark_conf)

# configure spark_event_log
spark_conf = self._append_event_log_conf(spark_conf, *aws_creds)
if aws_creds:
spark_conf = self._append_event_log_conf(spark_conf, *aws_creds)
else:
spark_conf = self._append_event_log_conf(spark_conf)

# configure sql shuffle partitions
spark_conf = self._append_sql_partitions_conf(spark_conf)
@@ -1158,7 +1165,9 @@
if is_jupyter:
spark_conf = _append_spark_config(spark_conf, 'spark.ui.showConsoleProgress', 'true')

spark_conf = _append_aws_credentials_conf(spark_conf, *aws_creds, aws_region)
if aws_creds:
spark_conf = _append_aws_credentials_conf(spark_conf, *aws_creds, aws_region)

return spark_conf


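For readers skimming the diff: the reshaped _append_event_log_conf now only calls STS when an access key was actually passed in, and otherwise falls back to the runtime environment to pick an event log directory. This is also why access_key and secret_key now default to None, so get_spark_conf can call the method with no credentials at all. A simplified, standalone sketch of that decision follows; the function and variable names here are illustrative and not part of the library.

from typing import Mapping, Optional


def resolve_event_log_dir(
    srv_conf: Mapping[str, dict],
    runtime_env: str,
    account_id: Optional[str] = None,
) -> Optional[str]:
    # Mirrors the new branch logic: prefer an account-id match when AWS
    # credentials were supplied, otherwise use the current runtime env entry.
    environments = srv_conf.get('environments', {})
    if account_id is not None:
        for env_conf in environments.values():
            if env_conf.get('account_id') == account_id:
                return env_conf['default_event_log_dir']
        return None  # no preset dir for this account: event log gets disabled
    env_conf = environments.get(runtime_env)
    return env_conf['default_event_log_dir'] if env_conf else None


# Without AWS credentials the runtime environment alone decides the directory.
conf = {'environments': {'dev': {'account_id': '111', 'default_event_log_dir': 's3a://bucket/dev'}}}
assert resolve_event_log_dir(conf, 'dev') == 's3a://bucket/dev'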
58 changes: 57 additions & 1 deletion service_configuration_lib/utils.py
@@ -1,8 +1,12 @@
import base64
import contextlib
import errno
import hashlib
import logging
import random
import string
import uuid
from functools import lru_cache
from socket import error as SocketError
from socket import SO_REUSEADDR
from socket import socket
@@ -12,9 +16,12 @@

import yaml


DEFAULT_SPARK_RUN_CONFIG = '/nail/srv/configs/spark.yaml'

POD_TEMPLATE_PATH = '/nail/tmp/spark-pt-{file_uuid}.yaml'

SPARK_EXECUTOR_POD_TEMPLATE = '/nail/srv/configs/spark_executor_pod_template.yaml'

log = logging.Logger(__name__)
log.setLevel(logging.INFO)

@@ -85,3 +92,52 @@ def ephemeral_port_reserve_range(preferred_port_start: int, preferred_port_end:

def get_random_string(length: int) -> str:
return ''.join(random.choice(string.ascii_lowercase + string.digits) for _ in range(length))


def generate_pod_template_path() -> str:
return POD_TEMPLATE_PATH.format(file_uuid=uuid.uuid4().hex)


def create_pod_template(pod_template_path: str, app_base_name: str) -> None:
try:
with open(SPARK_EXECUTOR_POD_TEMPLATE, 'r') as fp:
parsed_pod_template = fp.read()
parsed_pod_template = parsed_pod_template.format(spark_pod_label=get_k8s_resource_name_limit_size_with_hash(
f'exec-{app_base_name}',
))
parsed_pod_template = yaml.safe_load(parsed_pod_template)
with open(pod_template_path, 'w') as f:
yaml.dump(parsed_pod_template, f)
except Exception as e:
log.warning(f'Failed to read and process {SPARK_EXECUTOR_POD_TEMPLATE}: {e}')
raise e


def get_k8s_resource_name_limit_size_with_hash(name: str, limit: int = 63, suffix: int = 4) -> str:
""" Returns `name` unchanged if it's length does not exceed the `limit`.
Otherwise, returns truncated `name` with its hash of size `suffix`
appended.
base32 encoding is chosen as it satisfies the common requirement in
various k8s names to be alphanumeric.
NOTE: This function is the same as paasta/paasta_tools/kubernetes_tools.py
"""
if len(name) > limit:
digest = hashlib.md5(name.encode()).digest()
hashed = base64.b32encode(digest).decode().replace('=', '').lower()
return f'{name[:(limit-suffix-1)]}-{hashed[:suffix]}'
else:
return name


@lru_cache(maxsize=1)
def get_runtime_env() -> str:
try:
with open('/nail/etc/runtimeenv', mode='r') as f:
return f.read()
except OSError:
log.error('Unable to read runtimeenv - this is not expected if inside Yelp.')
# we could also just crash or return None, but this seems a little easier to find
# should we somehow run into this at Yelp
return 'unknown'
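The helpers added above are what get_spark_conf wires together for the executor pod template: generate_pod_template_path() produces a unique path under /nail/tmp, create_pod_template() renders the srv-config template with a size-limited spark_pod_label, and the resulting path is set as spark.kubernetes.executor.podTemplateFile via the new pod_template_path argument of _get_k8s_spark_env. A minimal usage sketch, with the app_base_name value being illustrative:

from service_configuration_lib import utils

app_base_name = 'example_app'  # illustrative value

# Unique target path of the form /nail/tmp/spark-pt-<uuid>.yaml.
pod_template_path = utils.generate_pod_template_path()
try:
    # Reads /nail/srv/configs/spark_executor_pod_template.yaml, substitutes a
    # 63-character-safe label derived from 'exec-<app_base_name>', and writes
    # the rendered YAML to pod_template_path.
    utils.create_pod_template(pod_template_path, app_base_name)
except Exception:
    # get_spark_conf logs the error and falls back to an empty path.
    pod_template_path = ''

# _get_k8s_spark_env then carries the path into the Spark configuration.
spark_conf = {'spark.kubernetes.executor.podTemplateFile': pod_template_path}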
2 changes: 1 addition & 1 deletion setup.py
@@ -17,7 +17,7 @@

setup(
name='service-configuration-lib',
version='2.18.11',
version='2.18.12',
provides=['service_configuration_lib'],
description='Start, stop, and inspect Yelp SOA services',
url='https://github.com/Yelp/service_configuration_lib',
34 changes: 20 additions & 14 deletions tests/spark_config_test.py
@@ -167,6 +167,7 @@ class TestGetSparkConf:
executor_cores = '10'
spark_app_base_name = 'test_app_base_name'
aws_provider_key = 'spark.hadoop.fs.s3a.aws.credentials.provider'
pod_template_path = 'test_pod_template_path'

@pytest.fixture
def mock_spark_srv_conf_file(self, tmpdir, monkeypatch):
@@ -1072,6 +1073,18 @@ def mock_update_spark_srv_configs(self):
with MockConfigFunction(spark_config.SparkConfBuilder, 'update_spark_srv_configs', return_value) as m:
yield m

@pytest.fixture
def mock_generate_pod_template_path(self):
return_value = self.pod_template_path
with mock.patch.object(utils, 'generate_pod_template_path', return_value=return_value) as m:
yield m

@pytest.fixture
def mock_create_pod_template(self):
return_value = None
with mock.patch.object(utils, 'create_pod_template', return_value=return_value) as m:
yield m

@pytest.fixture
def mock_secret(self, tmpdir, monkeypatch):
secret = 'secret'
@@ -1228,6 +1241,7 @@ def assert_kubernetes_conf(self, base_volumes, ui_port, mock_ephemeral_port_rese
'spark.kubernetes.executor.label.yelp.com/pool': self.pool,
'spark.kubernetes.executor.label.paasta.yelp.com/pool': self.pool,
'spark.kubernetes.executor.label.yelp.com/owner': 'core_ml',
'spark.kubernetes.executor.podTemplateFile': self.pod_template_path,
}
for i, volume in enumerate(base_volumes + self._get_k8s_base_volumes()):
expected_output[f'spark.kubernetes.executor.volumes.hostPath.{i}.mount.path'] = volume['containerPath']
@@ -1257,6 +1271,8 @@ def test_leaders_get_spark_conf_kubernetes(
mock_update_spark_srv_configs,
mock_spark_srv_conf_file,
mock_ephemeral_port_reserve_range,
mock_generate_pod_template_path,
mock_create_pod_template,
mock_time,
assert_ui_port,
assert_app_name,
@@ -1351,6 +1367,8 @@ def test_show_console_progress_jupyter(
mock_get_dra_configs,
mock_spark_srv_conf_file,
mock_ephemeral_port_reserve_range,
mock_generate_pod_template_path,
mock_create_pod_template,
mock_time,
assert_ui_port,
assert_app_name,
@@ -1393,6 +1411,8 @@ def test_local_spark(
mock_update_spark_srv_configs,
mock_spark_srv_conf_file,
mock_ephemeral_port_reserve_range,
mock_generate_pod_template_path,
mock_create_pod_template,
mock_time,
assert_ui_port,
assert_app_name,
@@ -1722,17 +1742,3 @@ def test_send_and_calculate_resources_cost(
mock_clusterman_metrics.util.costs.estimate_cost_per_hour.assert_called_once_with(
cluster='test-cluster', pool='test-pool', cpus=10, mem=2048,
)


@pytest.mark.parametrize(
'instance_name,expected_instance_label',
(
('my_job.do_something', 'my_job.do_something'),
(
f"my_job.{'a'* 100}",
'my_job.aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa-6xhe',
),
),
)
def test_get_k8s_resource_name_limit_size_with_hash(instance_name, expected_instance_label):
assert expected_instance_label == spark_config._get_k8s_resource_name_limit_size_with_hash(instance_name)
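The parametrized test removed above is dropped from tests/spark_config_test.py because the helper moved to utils; presumably a counterpart now exercises utils.get_k8s_resource_name_limit_size_with_hash with the same cases. A sketch of what that relocated test would look like, with the file placement assumed (e.g. tests/utils_test.py):

import pytest

from service_configuration_lib import utils


@pytest.mark.parametrize(
    'instance_name,expected_instance_label',
    (
        ('my_job.do_something', 'my_job.do_something'),
        (
            f"my_job.{'a' * 100}",
            'my_job.aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa-6xhe',
        ),
    ),
)
def test_get_k8s_resource_name_limit_size_with_hash(instance_name, expected_instance_label):
    # Same expectations as the removed test; only the import target changed.
    assert expected_instance_label == utils.get_k8s_resource_name_limit_size_with_hash(instance_name)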