Skip to content

Commit

Permalink
fix test
Browse files Browse the repository at this point in the history
  • Loading branch information
lu-wang-dl committed May 24, 2024
1 parent 7250ea9 commit ccde2d8
Show file tree
Hide file tree
Showing 2 changed files with 57 additions and 14 deletions.
20 changes: 6 additions & 14 deletions joblibspark/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,10 +37,11 @@
from py4j.clientserver import ClientServer

import pyspark
from pyspark.sql import SparkSession
from pyspark import cloudpickle
from pyspark.util import VersionUtils

from .utils import create_resource_profile, get_spark_session


def register():
"""
Expand Down Expand Up @@ -76,10 +77,7 @@ def __init__(self,
super(SparkDistributedBackend, self).__init__(**backend_args)
self._pool = None
self._n_jobs = None
self._spark = SparkSession \
.builder \
.appName("JoblibSparkBackend") \
.getOrCreate()
self._spark = get_spark_session()
self._spark_context = self._spark.sparkContext
self._job_group = "joblib-spark-job-group-" + str(uuid.uuid4())
self._spark_pinned_threads_enabled = isinstance(
Expand Down Expand Up @@ -120,15 +118,9 @@ def _create_resource_profile(self,
if num_cpus_worker_node is not None and num_gpus_worker_node is not None:
if self._support_stage_scheduling:
self.using_stage_scheduling = True
#pylint: disable=import-outside-toplevel
from pyspark.resource.profile import ResourceProfileBuilder
from pyspark.resource.requests import TaskResourceRequests
#pylint: enable=import-outside-toplevel

task_res_req = TaskResourceRequests().cpus(num_cpus_worker_node)
if num_gpus_worker_node > 0:
task_res_req = task_res_req.resource("gpu", num_gpus_worker_node)
resource_profile = ResourceProfileBuilder().require(task_res_req).build

resource_profile = create_resource_profile(num_cpus_worker_node,
num_gpus_worker_node)

return resource_profile

Expand Down
51 changes: 51 additions & 0 deletions joblibspark/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""
The utils functions for joblib spark backend.
"""


# pylint: disable=import-outside-toplevel
def get_spark_session():
    """
    Return the SparkSession to run the joblib backend on.

    Prefers the currently active session (so the backend attaches to the
    session the user is already working in); only when no session is active
    does it build a new one named "JoblibSparkBackend".
    """
    from pyspark.sql import SparkSession

    active = SparkSession.getActiveSession()
    if active is not None:
        return active
    return (
        SparkSession.builder
        .appName("JoblibSparkBackend")
        .getOrCreate()
    )


def create_resource_profile(num_cpus_worker_node, num_gpus_worker_node):
    """
    Build a Spark ResourceProfile describing per-task resource needs.

    :param num_cpus_worker_node: Number of cpus for the Spark worker node.
    :param num_gpus_worker_node: Number of gpus for the Spark worker node.
    :return: Spark ResourceProfile
    """
    from pyspark.resource.profile import ResourceProfileBuilder
    from pyspark.resource.requests import TaskResourceRequests

    requests = TaskResourceRequests().cpus(num_cpus_worker_node)
    if num_gpus_worker_node > 0:
        requests = requests.resource("gpu", num_gpus_worker_node)
    builder = ResourceProfileBuilder().require(requests)
    # NOTE: ResourceProfileBuilder.build is a property, not a method —
    # the absence of parentheses here is intentional.
    return builder.build

0 comments on commit ccde2d8

Please sign in to comment.