diff --git a/joblibspark/backend.py b/joblibspark/backend.py index 34a6556..4b7e937 100644 --- a/joblibspark/backend.py +++ b/joblibspark/backend.py @@ -37,10 +37,11 @@ from py4j.clientserver import ClientServer import pyspark -from pyspark.sql import SparkSession from pyspark import cloudpickle from pyspark.util import VersionUtils +from .utils import create_resource_profile, get_spark_session + def register(): """ @@ -76,10 +77,7 @@ def __init__(self, super(SparkDistributedBackend, self).__init__(**backend_args) self._pool = None self._n_jobs = None - self._spark = SparkSession \ - .builder \ - .appName("JoblibSparkBackend") \ - .getOrCreate() + self._spark = get_spark_session() self._spark_context = self._spark.sparkContext self._job_group = "joblib-spark-job-group-" + str(uuid.uuid4()) self._spark_pinned_threads_enabled = isinstance( @@ -120,15 +118,9 @@ def _create_resource_profile(self, if num_cpus_worker_node is not None and num_gpus_worker_node is not None: if self._support_stage_scheduling: self.using_stage_scheduling = True - #pylint: disable=import-outside-toplevel - from pyspark.resource.profile import ResourceProfileBuilder - from pyspark.resource.requests import TaskResourceRequests - #pylint: enable=import-outside-toplevel - - task_res_req = TaskResourceRequests().cpus(num_cpus_worker_node) - if num_gpus_worker_node > 0: - task_res_req = task_res_req.resource("gpu", num_gpus_worker_node) - resource_profile = ResourceProfileBuilder().require(task_res_req).build + + resource_profile = create_resource_profile(num_cpus_worker_node, + num_gpus_worker_node) return resource_profile diff --git a/joblibspark/utils.py b/joblibspark/utils.py new file mode 100644 index 0000000..7545402 --- /dev/null +++ b/joblibspark/utils.py @@ -0,0 +1,51 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +""" +The utils functions for joblib spark backend. +""" + + +# pylint: disable=import-outside-toplevel +def get_spark_session(): + """ + Get the spark session from the active session or create a new one. + """ + from pyspark.sql import SparkSession + + spark_session = SparkSession.getActiveSession() + if spark_session is None: + spark_session = SparkSession \ + .builder \ + .appName("JoblibSparkBackend") \ + .getOrCreate() + return spark_session + + +def create_resource_profile(num_cpus_worker_node, num_gpus_worker_node): + """ + Create a resource profile for the task. + :param num_cpus_worker_node: Number of cpus for the Spark worker node. + :param num_gpus_worker_node: Number of gpus for the Spark worker node. + :return: Spark ResourceProfile + """ + from pyspark.resource.profile import ResourceProfileBuilder + from pyspark.resource.requests import TaskResourceRequests + + task_res_req = TaskResourceRequests().cpus(num_cpus_worker_node) + if num_gpus_worker_node > 0: + task_res_req = task_res_req.resource("gpu", num_gpus_worker_node) + return ResourceProfileBuilder().require(task_res_req).build