diff --git a/service_configuration_lib/spark_config.py b/service_configuration_lib/spark_config.py
index 99fa94b..af22a46 100644
--- a/service_configuration_lib/spark_config.py
+++ b/service_configuration_lib/spark_config.py
@@ -32,6 +32,7 @@ GPUS_HARD_LIMIT = 15
 CLUSTERMAN_METRICS_YAML_FILE_PATH = '/nail/srv/configs/clusterman_metrics.yaml'
 CLUSTERMAN_YAML_FILE_PATH = '/nail/srv/configs/clusterman.yaml'
+SPARK_TRON_JOB_USER = 'TRON'

 NON_CONFIGURABLE_SPARK_OPTS = {
     'spark.master',
@@ -295,7 +296,7 @@ def _get_k8s_spark_env(
     paasta_service: str,
     paasta_instance: str,
     docker_img: str,
-    pod_template_path: str,
+    pod_template_path: Optional[str],
     volumes: Optional[List[Mapping[str, str]]],
     paasta_pool: str,
     driver_ui_port: int,
@@ -335,9 +336,12 @@ def _get_k8s_spark_env(
         'spark.kubernetes.executor.label.yelp.com/pool': paasta_pool,
         'spark.kubernetes.executor.label.paasta.yelp.com/pool': paasta_pool,
         'spark.kubernetes.executor.label.yelp.com/owner': 'core_ml',
-        'spark.kubernetes.executor.podTemplateFile': pod_template_path,
         **_get_k8s_docker_volumes_conf(volumes),
     }
+
+    if pod_template_path is not None:
+        spark_env['spark.kubernetes.executor.podTemplateFile'] = pod_template_path
+
     if service_account_name is not None:
         spark_env.update(
             {
@@ -419,7 +423,8 @@ def get_total_driver_memory_mb(spark_conf: Dict[str, str]) -> int:

 class SparkConfBuilder:

-    def __init__(self):
+    def __init__(self, is_driver_on_k8s_tron: bool = False):
+        self.is_driver_on_k8s_tron = is_driver_on_k8s_tron
         self.spark_srv_conf = dict()
         self.spark_constants = dict()
         self.default_spark_srv_conf = dict()
@@ -628,7 +633,7 @@ def compute_executor_instances_k8s(self, user_spark_opts: Dict[str, str]) -> int
             )

         # Deprecation message
-        if 'spark.cores.max' in user_spark_opts:
+        if not self.is_driver_on_k8s_tron and 'spark.cores.max' in user_spark_opts:
             log.warning(
                 f'spark.cores.max is DEPRECATED. Replace with '
                 f'spark.executor.instances={executor_instances} in --spark-args and in your service code '
@@ -1105,20 +1110,22 @@ def get_spark_conf(
         # Pick a port from a pre-defined port range, which will then be used by our Jupyter
         # server metric aggregator API. The aggregator API collects Prometheus metrics from multiple
         # Spark sessions and exposes them through a single endpoint.
-        try:
-            ui_port = int(
-                (spark_opts_from_env or {}).get('spark.ui.port') or
-                utils.ephemeral_port_reserve_range(
-                    self.spark_constants.get('preferred_spark_ui_port_start'),
-                    self.spark_constants.get('preferred_spark_ui_port_end'),
-                ),
-            )
-        except Exception as e:
-            log.warning(
-                f'Could not get an available port using srv-config port range: {e}. '
-                'Using default port range to get an available port.',
-            )
-            ui_port = utils.ephemeral_port_reserve_range()
+        ui_port = self.spark_constants.get('preferred_spark_ui_port_start')
+        if not self.is_driver_on_k8s_tron:
+            try:
+                ui_port = int(
+                    (spark_opts_from_env or {}).get('spark.ui.port') or
+                    utils.ephemeral_port_reserve_range(
+                        self.spark_constants.get('preferred_spark_ui_port_start'),
+                        self.spark_constants.get('preferred_spark_ui_port_end'),
+                    ),
+                )
+            except Exception as e:
+                log.warning(
+                    f'Could not get an available port using srv-config port range: {e}. '
+                    'Using default port range to get an available port.',
+                )
+                ui_port = utils.ephemeral_port_reserve_range()

         spark_conf = {**(spark_opts_from_env or {}), **_filter_user_spark_opts(user_spark_opts)}
         random_postfix = utils.get_random_string(4)
@@ -1157,12 +1164,14 @@ def get_spark_conf(
         )

         # Add pod template file
-        pod_template_path = utils.generate_pod_template_path()
-        try:
-            utils.create_pod_template(pod_template_path, app_base_name)
-        except Exception as e:
-            log.error(f'Failed to generate Spark executor pod template: {e}')
-            pod_template_path = ''
+        pod_template_path: Optional[str] = None
+        if not self.is_driver_on_k8s_tron:
+            pod_template_path = utils.generate_pod_template_path()
+            try:
+                utils.create_pod_template(pod_template_path, app_base_name)
+            except Exception as e:
+                log.error(f'Failed to generate Spark executor pod template: {e}')
+                pod_template_path = None

         if cluster_manager == 'kubernetes':
             spark_conf.update(_get_k8s_spark_env(