Running parallel gridsearch via databricks-connect #24

Open · stelsemeyer opened this issue Aug 18, 2020 · 0 comments

stelsemeyer commented Aug 18, 2020

We are currently evaluating joblib-spark for parallel training and hyperparameter tuning of some sklearn models.
With a local Spark instance the repro at the end of this report works like a charm, but with databricks-connect (6.5 and 7.1) it fails. I have seen that a similar issue was mentioned before, but the committed fix did not help. I tried both joblib-spark 0.2.0 and the version on master.
My guess is that something goes wrong when the pickled sklearn models are serialized, shipped to the workers, and reloaded there.
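
To isolate that serialization path from sklearn itself, here is a minimal round-trip that mimics what joblibspark/backend.py does (cloudpickle a callable, run it in a one-partition RDD, unpickle the collected result). This is only a sketch, not part of the original repro, and it assumes a configured databricks-connect session:

import cloudpickle
from pyspark.sql import SparkSession

# Sketch only: `spark`/`sc` are assumed to come from a configured
# databricks-connect session.
spark = SparkSession.builder.getOrCreate()
sc = spark.sparkContext

def task():
    # The kind of payload joblib-spark ships: code that builds an sklearn estimator.
    from sklearn.ensemble import GradientBoostingRegressor
    return GradientBoostingRegressor(n_estimators=50)

ser_task = cloudpickle.dumps(task)

def run_on_worker(_):
    import cloudpickle
    fn = cloudpickle.loads(ser_task)  # deserialize the shipped callable
    return cloudpickle.dumps(fn())    # serialize the result back to the driver

ser_res = sc.parallelize([0], 1).map(run_on_worker).collect()[0]
print(cloudpickle.loads(ser_res))

If this fails with the same TypeError, the problem is in the pickle round-trip itself rather than in joblib-spark.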

Packages:
joblib==0.14.0
scikit-learn==0.21.3
databricks-connect==6.5.0 and 7.1.0
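
Since the TypeError below is raised inside pickle.loads on the worker, a version mismatch between my local client environment and the cluster also seems plausible. A quick check along these lines (again only a sketch, assuming an active SparkContext sc) compares the two sides:

import sys
import cloudpickle, joblib, sklearn

def worker_versions(_):
    # Runs on an executor; reports the versions installed there.
    import sys
    import cloudpickle, joblib, sklearn
    return (sys.version, cloudpickle.__version__, joblib.__version__, sklearn.__version__)

driver = (sys.version, cloudpickle.__version__, joblib.__version__, sklearn.__version__)
executor = sc.parallelize([0], 1).map(worker_versions).collect()[0]  # `sc` assumed
print("driver:  ", driver)
print("executor:", executor)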

The full repro:

from sklearn.datasets import load_boston
from sklearn.ensemble import GradientBoostingRegressor
from sklearn.model_selection import GridSearchCV

from sklearn.utils import parallel_backend
from joblibspark import register_spark

# Register the Spark backend with joblib before requesting it by name.
register_spark()

x, y = load_boston(return_X_y=True)

cv = GridSearchCV(estimator=GradientBoostingRegressor(),
                  param_grid={'n_estimators': [50, 100, 150]})

# Dispatch the cross-validation fits to Spark executors.
with parallel_backend('spark', n_jobs=2):
    cv.fit(x, y)

cv.best_params_

This raises:

---------------------------------------------------------------------------
Py4JJavaError                             Traceback (most recent call last)
<ipython-input-3-309e74709030> in <module>
      5 
      6 with parallel_backend('spark', n_jobs=2):
----> 7     cv.fit(x, y)
      8 
      9 cv.best_params_

~/.virtualenvs/databricks-spark/lib/python3.7/site-packages/sklearn/utils/validation.py in inner_f(*args, **kwargs)
     70                           FutureWarning)
     71         kwargs.update({k: arg for k, arg in zip(sig.parameters, args)})
---> 72         return f(**kwargs)
     73     return inner_f
     74 

~/.virtualenvs/databricks-spark/lib/python3.7/site-packages/sklearn/model_selection/_search.py in fit(self, X, y, groups, **fit_params)
    734                 return results
    735 
--> 736             self._run_search(evaluate_candidates)
    737 
    738         # For multi-metric evaluation, store the best_index_, best_params_ and

~/.virtualenvs/databricks-spark/lib/python3.7/site-packages/sklearn/model_selection/_search.py in _run_search(self, evaluate_candidates)
   1186     def _run_search(self, evaluate_candidates):
   1187         """Search all candidates in param_grid"""
-> 1188         evaluate_candidates(ParameterGrid(self.param_grid))
   1189 
   1190 

~/.virtualenvs/databricks-spark/lib/python3.7/site-packages/sklearn/model_selection/_search.py in evaluate_candidates(candidate_params)
    713                                for parameters, (train, test)
    714                                in product(candidate_params,
--> 715                                           cv.split(X, y, groups)))
    716 
    717                 if len(out) < 1:

~/.virtualenvs/databricks-spark/lib/python3.7/site-packages/joblib/parallel.py in __call__(self, iterable)
   1040 
   1041             with self._backend.retrieval_context():
-> 1042                 self.retrieve()
   1043             # Make sure that we get a last message telling us we are done
   1044             elapsed_time = time.time() - self._start_time

~/.virtualenvs/databricks-spark/lib/python3.7/site-packages/joblib/parallel.py in retrieve(self)
    921                     self._output.extend(job.get(timeout=self.timeout))
    922                 else:
--> 923                     self._output.extend(job.get())
    924 
    925             except BaseException as exception:

~/.pyenv/versions/3.7.8/lib/python3.7/multiprocessing/pool.py in get(self, timeout)
    655             return self._value
    656         else:
--> 657             raise self._value
    658 
    659     def _set(self, i, obj):

~/.pyenv/versions/3.7.8/lib/python3.7/multiprocessing/pool.py in worker(inqueue, outqueue, initializer, initargs, maxtasks, wrap_exception)
    119         job, i, func, args, kwds = task
    120         try:
--> 121             result = (True, func(*args, **kwds))
    122         except Exception as e:
    123             if wrap_exception and func is not _helper_reraises_exception:

~/.virtualenvs/databricks-spark/lib/python3.7/site-packages/joblib/_parallel_backends.py in __call__(self, *args, **kwargs)
    593     def __call__(self, *args, **kwargs):
    594         try:
--> 595             return self.func(*args, **kwargs)
    596         except KeyboardInterrupt as e:
    597             # We capture the KeyboardInterrupt and reraise it as

~/.virtualenvs/databricks-spark/lib/python3.7/site-packages/joblibspark/backend.py in run_on_worker_and_fetch_result()
    132                 ser_res = rdd.collect()[0]
    133             else:
--> 134                 ser_res = rdd.collectWithJobGroup(self._job_group, "joblib spark jobs")[0]
    135             return cloudpickle.loads(ser_res)
    136 

~/.virtualenvs/databricks-spark/lib/python3.7/site-packages/pyspark/rdd.py in collectWithJobGroup(self, groupId, description, interruptOnCancel)
    929         with SCCallSiteSync(self.context) as css:
    930             sock_info = self.ctx._jvm.PythonRDD.collectAndServeWithJobGroup(
--> 931                 self._jrdd.rdd(), groupId, description, interruptOnCancel)
    932         return list(_load_from_socket(sock_info, self._jrdd_deserializer))
    933 

~/.virtualenvs/databricks-spark/lib/python3.7/site-packages/py4j/java_gateway.py in __call__(self, *args)
   1303         answer = self.gateway_client.send_command(command)
   1304         return_value = get_return_value(
-> 1305             answer, self.gateway_client, self.target_id, self.name)
   1306 
   1307         for temp_arg in temp_args:

~/.virtualenvs/databricks-spark/lib/python3.7/site-packages/pyspark/sql/utils.py in deco(*a, **kw)
    126     def deco(*a, **kw):
    127         try:
--> 128             return f(*a, **kw)
    129         except py4j.protocol.Py4JJavaError as e:
    130             converted = convert_exception(e.java_exception)

~/.virtualenvs/databricks-spark/lib/python3.7/site-packages/py4j/protocol.py in get_return_value(answer, gateway_client, target_id, name)
    326                 raise Py4JJavaError(
    327                     "An error occurred while calling {0}{1}{2}.\n".
--> 328                     format(target_id, ".", name), value)
    329             else:
    330                 raise Py4JError(

Py4JJavaError: An error occurred while calling z:org.apache.spark.api.python.PythonRDD.collectAndServeWithJobGroup.
: org.apache.spark.SparkException: Job aborted due to stage failure: Task 0 in stage 21.0 failed 4 times, most recent failure: Lost task 0.3 in stage 21.0 (TID 87, 10.99.228.125, executor 0): org.apache.spark.api.python.PythonException: Traceback (most recent call last):
  File "/databricks/spark/python/pyspark/serializers.py", line 177, in _read_with_length
    return self.loads(obj)
  File "/databricks/spark/python/pyspark/serializers.py", line 466, in loads
    return pickle.loads(obj, encoding=encoding)
TypeError: __init__() takes from 3 to 4 positional arguments but 5 were given

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "/databricks/spark/python/pyspark/worker.py", line 642, in main
    func, profiler, deserializer, serializer = read_command(pickleSer, infile)
  File "/databricks/spark/python/pyspark/worker.py", line 74, in read_command
    command = serializer._read_with_length(file)
  File "/databricks/spark/python/pyspark/serializers.py", line 180, in _read_with_length
    raise SerializationError("Caused by " + traceback.format_exc())
pyspark.serializers.SerializationError: Caused by Traceback (most recent call last):
  File "/databricks/spark/python/pyspark/serializers.py", line 177, in _read_with_length
    return self.loads(obj)
  File "/databricks/spark/python/pyspark/serializers.py", line 466, in loads
    return pickle.loads(obj, encoding=encoding)
TypeError: __init__() takes from 3 to 4 positional arguments but 5 were given


	at org.apache.spark.api.python.BasePythonRunner$ReaderIterator.handlePythonException(PythonRunner.scala:585)
	at org.apache.spark.api.python.PythonRunner$$anon$3.read(PythonRunner.scala:720)
	at org.apache.spark.api.python.PythonRunner$$anon$3.read(PythonRunner.scala:703)
	at org.apache.spark.api.python.BasePythonRunner$ReaderIterator.hasNext(PythonRunner.scala:538)
	at org.apache.spark.InterruptibleIterator.hasNext(InterruptibleIterator.scala:37)
	at scala.collection.Iterator.foreach(Iterator.scala:941)
	at scala.collection.Iterator.foreach$(Iterator.scala:941)
	at org.apache.spark.InterruptibleIterator.foreach(InterruptibleIterator.scala:28)
	at scala.collection.generic.Growable.$plus$plus$eq(Growable.scala:62)
	at scala.collection.generic.Growable.$plus$plus$eq$(Growable.scala:53)
	at scala.collection.mutable.ArrayBuffer.$plus$plus$eq(ArrayBuffer.scala:105)
	at scala.collection.mutable.ArrayBuffer.$plus$plus$eq(ArrayBuffer.scala:49)
	at scala.collection.TraversableOnce.to(TraversableOnce.scala:315)
	at scala.collection.TraversableOnce.to$(TraversableOnce.scala:313)
	at org.apache.spark.InterruptibleIterator.to(InterruptibleIterator.scala:28)
	at scala.collection.TraversableOnce.toBuffer(TraversableOnce.scala:307)
	at scala.collection.TraversableOnce.toBuffer$(TraversableOnce.scala:307)
	at org.apache.spark.InterruptibleIterator.toBuffer(InterruptibleIterator.scala:28)
	at scala.collection.TraversableOnce.toArray(TraversableOnce.scala:294)
	at scala.collection.TraversableOnce.toArray$(TraversableOnce.scala:288)
	at org.apache.spark.InterruptibleIterator.toArray(InterruptibleIterator.scala:28)
	at org.apache.spark.rdd.RDD.$anonfun$collect$2(RDD.scala:1011)
	at org.apache.spark.SparkContext.$anonfun$runJob$5(SparkContext.scala:2353)
	at org.apache.spark.scheduler.ResultTask.runTask(ResultTask.scala:90)
	at org.apache.spark.scheduler.Task.doRunTask(Task.scala:144)
	at org.apache.spark.scheduler.Task.run(Task.scala:117)
	at org.apache.spark.executor.Executor$TaskRunner.$anonfun$run$11(Executor.scala:657)
	at org.apache.spark.util.Utils$.tryWithSafeFinally(Utils.scala:1581)
	at org.apache.spark.executor.Executor$TaskRunner.run(Executor.scala:660)
	at java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1149)
	at java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:624)
	at java.lang.Thread.run(Thread.java:748)

Driver stacktrace:
	at org.apache.spark.scheduler.DAGScheduler.failJobAndIndependentStages(DAGScheduler.scala:2476)
	at org.apache.spark.scheduler.DAGScheduler.$anonfun$abortStage$2(DAGScheduler.scala:2425)
	at org.apache.spark.scheduler.DAGScheduler.$anonfun$abortStage$2$adapted(DAGScheduler.scala:2424)
	at scala.collection.mutable.ResizableArray.foreach(ResizableArray.scala:62)
	at scala.collection.mutable.ResizableArray.foreach$(ResizableArray.scala:55)
	at scala.collection.mutable.ArrayBuffer.foreach(ArrayBuffer.scala:49)
	at org.apache.spark.scheduler.DAGScheduler.abortStage(DAGScheduler.scala:2424)
	at org.apache.spark.scheduler.DAGScheduler.$anonfun$handleTaskSetFailed$1(DAGScheduler.scala:1129)
	at org.apache.spark.scheduler.DAGScheduler.$anonfun$handleTaskSetFailed$1$adapted(DAGScheduler.scala:1129)
	at scala.Option.foreach(Option.scala:407)
	at org.apache.spark.scheduler.DAGScheduler.handleTaskSetFailed(DAGScheduler.scala:1129)
	at org.apache.spark.scheduler.DAGSchedulerEventProcessLoop.doOnReceive(DAGScheduler.scala:2676)
	at org.apache.spark.scheduler.DAGSchedulerEventProcessLoop.onReceive(DAGScheduler.scala:2623)
	at org.apache.spark.scheduler.DAGSchedulerEventProcessLoop.onReceive(DAGScheduler.scala:2611)
	at org.apache.spark.util.EventLoop$$anon$1.run(EventLoop.scala:49)
	at org.apache.spark.scheduler.DAGScheduler.runJob(DAGScheduler.scala:915)
	at org.apache.spark.SparkContext.runJob(SparkContext.scala:2313)
	at org.apache.spark.SparkContext.runJob(SparkContext.scala:2334)
	at com.databricks.service.SparkServiceImpl$.$anonfun$executeRDD$3(SparkServiceImpl.scala:146)
	at org.apache.spark.SparkContext.withCallSite(SparkContext.scala:2283)
	at com.databricks.service.SparkServiceImpl$.$anonfun$executeRDD$1(SparkServiceImpl.scala:146)
	at com.databricks.logging.UsageLogging.$anonfun$recordOperation$4(UsageLogging.scala:430)
	at com.databricks.logging.UsageLogging.$anonfun$withAttributionContext$1(UsageLogging.scala:238)
	at scala.util.DynamicVariable.withValue(DynamicVariable.scala:62)
	at com.databricks.logging.UsageLogging.withAttributionContext(UsageLogging.scala:233)
	at com.databricks.logging.UsageLogging.withAttributionContext$(UsageLogging.scala:230)
	at com.databricks.spark.util.PublicDBLogging.withAttributionContext(DatabricksSparkUsageLogger.scala:18)
	at com.databricks.logging.UsageLogging.withAttributionTags(UsageLogging.scala:275)
	at com.databricks.logging.UsageLogging.withAttributionTags$(UsageLogging.scala:268)
	at com.databricks.spark.util.PublicDBLogging.withAttributionTags(DatabricksSparkUsageLogger.scala:18)
	at com.databricks.logging.UsageLogging.recordOperation(UsageLogging.scala:411)
	at com.databricks.logging.UsageLogging.recordOperation$(UsageLogging.scala:337)
	at com.databricks.spark.util.PublicDBLogging.recordOperation(DatabricksSparkUsageLogger.scala:18)
	at com.databricks.spark.util.PublicDBLogging.recordOperation0(DatabricksSparkUsageLogger.scala:55)
	at com.databricks.spark.util.DatabricksSparkUsageLogger.recordOperation(DatabricksSparkUsageLogger.scala:98)
	at com.databricks.spark.util.UsageLogger.recordOperation(UsageLogger.scala:69)
	at com.databricks.spark.util.UsageLogger.recordOperation$(UsageLogger.scala:56)
	at com.databricks.spark.util.DatabricksSparkUsageLogger.recordOperation(DatabricksSparkUsageLogger.scala:67)
	at com.databricks.spark.util.UsageLogging.recordOperation(UsageLogger.scala:344)
	at com.databricks.spark.util.UsageLogging.recordOperation$(UsageLogger.scala:323)
	at com.databricks.service.SparkServiceImpl$.recordOperation(SparkServiceImpl.scala:91)
	at com.databricks.service.SparkServiceImpl$.executeRDD(SparkServiceImpl.scala:143)
	at com.databricks.service.SparkServiceRPCHandler.execute0(SparkServiceRPCHandler.scala:663)
	at com.databricks.service.SparkServiceRPCHandler.$anonfun$executeRPC0$1(SparkServiceRPCHandler.scala:439)
	at scala.util.DynamicVariable.withValue(DynamicVariable.scala:62)
	at com.databricks.service.SparkServiceRPCHandler.executeRPC0(SparkServiceRPCHandler.scala:346)
	at com.databricks.service.SparkServiceRPCHandler$$anon$2.call(SparkServiceRPCHandler.scala:299)
	at com.databricks.service.SparkServiceRPCHandler$$anon$2.call(SparkServiceRPCHandler.scala:285)
	at java.util.concurrent.FutureTask.run(FutureTask.java:266)
	at com.databricks.service.SparkServiceRPCHandler.$anonfun$executeRPC$1(SparkServiceRPCHandler.scala:333)
	at scala.util.DynamicVariable.withValue(DynamicVariable.scala:62)
	at com.databricks.service.SparkServiceRPCHandler.executeRPC(SparkServiceRPCHandler.scala:311)
	at com.databricks.service.SparkServiceRPCServlet.doPost(SparkServiceRPCServer.scala:145)
	at javax.servlet.http.HttpServlet.service(HttpServlet.java:707)
	at javax.servlet.http.HttpServlet.service(HttpServlet.java:790)
	at org.eclipse.jetty.servlet.ServletHolder.handle(ServletHolder.java:873)
	at org.eclipse.jetty.servlet.ServletHandler.doHandle(ServletHandler.java:542)
	at org.eclipse.jetty.server.handler.ScopedHandler.nextScope(ScopedHandler.java:205)
	at org.eclipse.jetty.servlet.ServletHandler.doScope(ServletHandler.java:480)
	at org.eclipse.jetty.server.handler.ScopedHandler.handle(ScopedHandler.java:144)
	at org.eclipse.jetty.server.handler.HandlerWrapper.handle(HandlerWrapper.java:132)
	at org.eclipse.jetty.server.Server.handle(Server.java:505)
	at org.eclipse.jetty.server.HttpChannel.handle(HttpChannel.java:370)
	at org.eclipse.jetty.server.HttpConnection.onFillable(HttpConnection.java:267)
	at org.eclipse.jetty.io.AbstractConnection$ReadCallback.succeeded(AbstractConnection.java:305)
	at org.eclipse.jetty.io.FillInterest.fillable(FillInterest.java:103)
	at org.eclipse.jetty.io.ChannelEndPoint$2.run(ChannelEndPoint.java:117)
	at org.eclipse.jetty.util.thread.strategy.EatWhatYouKill.runTask(EatWhatYouKill.java:333)
	at org.eclipse.jetty.util.thread.strategy.EatWhatYouKill.doProduce(EatWhatYouKill.java:310)
	at org.eclipse.jetty.util.thread.strategy.EatWhatYouKill.tryProduce(EatWhatYouKill.java:168)
	at org.eclipse.jetty.util.thread.strategy.EatWhatYouKill.run(EatWhatYouKill.java:126)
	at org.eclipse.jetty.util.thread.ReservedThreadExecutor$ReservedThread.run(ReservedThreadExecutor.java:366)
	at org.eclipse.jetty.util.thread.QueuedThreadPool.runJob(QueuedThreadPool.java:698)
	at org.eclipse.jetty.util.thread.QueuedThreadPool$Runner.run(QueuedThreadPool.java:804)
	at java.lang.Thread.run(Thread.java:748)
Caused by: org.apache.spark.api.python.PythonException: Traceback (most recent call last):
  File "/databricks/spark/python/pyspark/serializers.py", line 177, in _read_with_length
    return self.loads(obj)
  File "/databricks/spark/python/pyspark/serializers.py", line 466, in loads
    return pickle.loads(obj, encoding=encoding)
TypeError: __init__() takes from 3 to 4 positional arguments but 5 were given

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "/databricks/spark/python/pyspark/worker.py", line 642, in main
    func, profiler, deserializer, serializer = read_command(pickleSer, infile)
  File "/databricks/spark/python/pyspark/worker.py", line 74, in read_command
    command = serializer._read_with_length(file)
  File "/databricks/spark/python/pyspark/serializers.py", line 180, in _read_with_length
    raise SerializationError("Caused by " + traceback.format_exc())
pyspark.serializers.SerializationError: Caused by Traceback (most recent call last):
  File "/databricks/spark/python/pyspark/serializers.py", line 177, in _read_with_length
    return self.loads(obj)
  File "/databricks/spark/python/pyspark/serializers.py", line 466, in loads
    return pickle.loads(obj, encoding=encoding)
TypeError: __init__() takes from 3 to 4 positional arguments but 5 were given


	at org.apache.spark.api.python.BasePythonRunner$ReaderIterator.handlePythonException(PythonRunner.scala:585)
	at org.apache.spark.api.python.PythonRunner$$anon$3.read(PythonRunner.scala:720)
	at org.apache.spark.api.python.PythonRunner$$anon$3.read(PythonRunner.scala:703)
	at org.apache.spark.api.python.BasePythonRunner$ReaderIterator.hasNext(PythonRunner.scala:538)
	at org.apache.spark.InterruptibleIterator.hasNext(InterruptibleIterator.scala:37)
	at scala.collection.Iterator.foreach(Iterator.scala:941)
	at scala.collection.Iterator.foreach$(Iterator.scala:941)
	at org.apache.spark.InterruptibleIterator.foreach(InterruptibleIterator.scala:28)
	at scala.collection.generic.Growable.$plus$plus$eq(Growable.scala:62)
	at scala.collection.generic.Growable.$plus$plus$eq$(Growable.scala:53)
	at scala.collection.mutable.ArrayBuffer.$plus$plus$eq(ArrayBuffer.scala:105)
	at scala.collection.mutable.ArrayBuffer.$plus$plus$eq(ArrayBuffer.scala:49)
	at scala.collection.TraversableOnce.to(TraversableOnce.scala:315)
	at scala.collection.TraversableOnce.to$(TraversableOnce.scala:313)
	at org.apache.spark.InterruptibleIterator.to(InterruptibleIterator.scala:28)
	at scala.collection.TraversableOnce.toBuffer(TraversableOnce.scala:307)
	at scala.collection.TraversableOnce.toBuffer$(TraversableOnce.scala:307)
	at org.apache.spark.InterruptibleIterator.toBuffer(InterruptibleIterator.scala:28)
	at scala.collection.TraversableOnce.toArray(TraversableOnce.scala:294)
	at scala.collection.TraversableOnce.toArray$(TraversableOnce.scala:288)
	at org.apache.spark.InterruptibleIterator.toArray(InterruptibleIterator.scala:28)
	at org.apache.spark.rdd.RDD.$anonfun$collect$2(RDD.scala:1011)
	at org.apache.spark.SparkContext.$anonfun$runJob$5(SparkContext.scala:2353)
	at org.apache.spark.scheduler.ResultTask.runTask(ResultTask.scala:90)
	at org.apache.spark.scheduler.Task.doRunTask(Task.scala:144)
	at org.apache.spark.scheduler.Task.run(Task.scala:117)
	at org.apache.spark.executor.Executor$TaskRunner.$anonfun$run$11(Executor.scala:657)
	at org.apache.spark.util.Utils$.tryWithSafeFinally(Utils.scala:1581)
	at org.apache.spark.executor.Executor$TaskRunner.run(Executor.scala:660)
	at java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1149)
	at java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:624)
	... 1 more