We are currently evaluating joblib-spark for parallel training and hyperparameter tuning of some sklearn models.
With a local Spark instance the following code works flawlessly, but with databricks-connect (6.5 and 7.1) it fails. I have seen that a similar issue was mentioned before, but the committed fix did not help. I tried joblib-spark 0.2.0 as well as the version on master.
My guess is that something goes wrong when the pickled sklearn models are serialized and then reloaded on the workers.
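As a sanity check of that hypothesis (my own sketch, not part of joblib-spark), I can round-trip a fitted model through pyspark's vendored cloudpickle on the driver alone. It uses the same dumps path as joblib-spark's backend and the same pickle.loads the worker fails in; if this passes locally, the problem is presumably on the worker side of databricks-connect rather than in the serialization itself:

import pickle

from pyspark import cloudpickle  # the vendored cloudpickle used by joblib-spark
from sklearn.datasets import load_boston
from sklearn.ensemble import GradientBoostingRegressor

x, y = load_boston(return_X_y=True)
model = GradientBoostingRegressor(n_estimators=50).fit(x, y)

# Serialize with cloudpickle, deserialize with plain pickle, exactly as the
# Spark worker does. If this succeeds, driver-side serialization is fine.
restored = pickle.loads(cloudpickle.dumps(model))
assert (restored.predict(x[:5]) == model.predict(x[:5])).all()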
Packages:
joblib==0.14.0
scikit-learn==0.21.3
databricks-connect==6.5.0 and 7.1.0
from sklearn.datasets import load_boston
from sklearn.ensemble import GradientBoostingRegressor
from sklearn.model_selection import GridSearchCV
from sklearn.utils import parallel_backend
from joblibspark import register_spark

# Register the 'spark' joblib backend before using it.
register_spark()

x, y = load_boston(return_X_y=True)
cv = GridSearchCV(estimator=GradientBoostingRegressor(),
                  param_grid={'n_estimators': [50, 100, 150]})

with parallel_backend('spark', n_jobs=2):
    cv.fit(x, y)
raises
---------------------------------------------------------------------------
Py4JJavaError Traceback (most recent call last)
<ipython-input-3-309e74709030> in <module>
5
6 with parallel_backend('spark', n_jobs=2):
----> 7 cv.fit(x, y)
8
9 cv.best_params_
~/.virtualenvs/databricks-spark/lib/python3.7/site-packages/sklearn/utils/validation.py in inner_f(*args, **kwargs)
70 FutureWarning)
71 kwargs.update({k: arg for k, arg in zip(sig.parameters, args)})
---> 72 return f(**kwargs)
73 return inner_f
74
~/.virtualenvs/databricks-spark/lib/python3.7/site-packages/sklearn/model_selection/_search.py in fit(self, X, y, groups, **fit_params)
734 return results
735
--> 736 self._run_search(evaluate_candidates)
737
738 # For multi-metric evaluation, store the best_index_, best_params_ and
~/.virtualenvs/databricks-spark/lib/python3.7/site-packages/sklearn/model_selection/_search.py in _run_search(self, evaluate_candidates)
1186 def _run_search(self, evaluate_candidates):
1187 """Search all candidates in param_grid"""
-> 1188 evaluate_candidates(ParameterGrid(self.param_grid))
1189
1190
~/.virtualenvs/databricks-spark/lib/python3.7/site-packages/sklearn/model_selection/_search.py in evaluate_candidates(candidate_params)
713 for parameters, (train, test)
714 in product(candidate_params,
--> 715 cv.split(X, y, groups)))
716
717 if len(out) < 1:
~/.virtualenvs/databricks-spark/lib/python3.7/site-packages/joblib/parallel.py in __call__(self, iterable)
1040
1041 with self._backend.retrieval_context():
-> 1042 self.retrieve()
1043 # Make sure that we get a last message telling us we are done
1044 elapsed_time = time.time() - self._start_time
~/.virtualenvs/databricks-spark/lib/python3.7/site-packages/joblib/parallel.py in retrieve(self)
921 self._output.extend(job.get(timeout=self.timeout))
922 else:
--> 923 self._output.extend(job.get())
924
925 except BaseException as exception:
~/.pyenv/versions/3.7.8/lib/python3.7/multiprocessing/pool.py in get(self, timeout)
655 return self._value
656 else:
--> 657 raise self._value
658
659 def _set(self, i, obj):
~/.pyenv/versions/3.7.8/lib/python3.7/multiprocessing/pool.py in worker(inqueue, outqueue, initializer, initargs, maxtasks, wrap_exception)
119 job, i, func, args, kwds = task
120 try:
--> 121 result = (True, func(*args, **kwds))
122 except Exception as e:
123 if wrap_exception and func is not _helper_reraises_exception:
~/.virtualenvs/databricks-spark/lib/python3.7/site-packages/joblib/_parallel_backends.py in __call__(self, *args, **kwargs)
593 def __call__(self, *args, **kwargs):
594 try:
--> 595 return self.func(*args, **kwargs)
596 except KeyboardInterrupt as e:
597 # We capture the KeyboardInterrupt and reraise it as
~/.virtualenvs/databricks-spark/lib/python3.7/site-packages/joblibspark/backend.py in run_on_worker_and_fetch_result()
132 ser_res = rdd.collect()[0]
133 else:
--> 134 ser_res = rdd.collectWithJobGroup(self._job_group, "joblib spark jobs")[0]
135 return cloudpickle.loads(ser_res)
136
~/.virtualenvs/databricks-spark/lib/python3.7/site-packages/pyspark/rdd.py in collectWithJobGroup(self, groupId, description, interruptOnCancel)
929 with SCCallSiteSync(self.context) as css:
930 sock_info = self.ctx._jvm.PythonRDD.collectAndServeWithJobGroup(
--> 931 self._jrdd.rdd(), groupId, description, interruptOnCancel)
932 return list(_load_from_socket(sock_info, self._jrdd_deserializer))
933
~/.virtualenvs/databricks-spark/lib/python3.7/site-packages/py4j/java_gateway.py in __call__(self, *args)
1303 answer = self.gateway_client.send_command(command)
1304 return_value = get_return_value(
-> 1305 answer, self.gateway_client, self.target_id, self.name)
1306
1307 for temp_arg in temp_args:
~/.virtualenvs/databricks-spark/lib/python3.7/site-packages/pyspark/sql/utils.py in deco(*a, **kw)
126 def deco(*a, **kw):
127 try:
--> 128 return f(*a, **kw)
129 except py4j.protocol.Py4JJavaError as e:
130 converted = convert_exception(e.java_exception)
~/.virtualenvs/databricks-spark/lib/python3.7/site-packages/py4j/protocol.py in get_return_value(answer, gateway_client, target_id, name)
326 raise Py4JJavaError(
327 "An error occurred while calling {0}{1}{2}.\n".
--> 328 format(target_id, ".", name), value)
329 else:
330 raise Py4JError(
Py4JJavaError: An error occurred while calling z:org.apache.spark.api.python.PythonRDD.collectAndServeWithJobGroup.
: org.apache.spark.SparkException: Job aborted due to stage failure: Task 0 in stage 21.0 failed 4 times, most recent failure: Lost task 0.3 in stage 21.0 (TID 87, 10.99.228.125, executor 0): org.apache.spark.api.python.PythonException: Traceback (most recent call last):
File "/databricks/spark/python/pyspark/serializers.py", line 177, in _read_with_length
return self.loads(obj)
File "/databricks/spark/python/pyspark/serializers.py", line 466, in loads
return pickle.loads(obj, encoding=encoding)
TypeError: __init__() takes from 3 to 4 positional arguments but 5 were given
During handling of the above exception, another exception occurred:
Traceback (most recent call last):
File "/databricks/spark/python/pyspark/worker.py", line 642, in main
func, profiler, deserializer, serializer = read_command(pickleSer, infile)
File "/databricks/spark/python/pyspark/worker.py", line 74, in read_command
command = serializer._read_with_length(file)
File "/databricks/spark/python/pyspark/serializers.py", line 180, in _read_with_length
raise SerializationError("Caused by " + traceback.format_exc())
pyspark.serializers.SerializationError: Caused by Traceback (most recent call last):
File "/databricks/spark/python/pyspark/serializers.py", line 177, in _read_with_length
return self.loads(obj)
File "/databricks/spark/python/pyspark/serializers.py", line 466, in loads
return pickle.loads(obj, encoding=encoding)
TypeError: __init__() takes from 3 to 4 positional arguments but 5 were given
at org.apache.spark.api.python.BasePythonRunner$ReaderIterator.handlePythonException(PythonRunner.scala:585)
at org.apache.spark.api.python.PythonRunner$$anon$3.read(PythonRunner.scala:720)
at org.apache.spark.api.python.PythonRunner$$anon$3.read(PythonRunner.scala:703)
at org.apache.spark.api.python.BasePythonRunner$ReaderIterator.hasNext(PythonRunner.scala:538)
at org.apache.spark.InterruptibleIterator.hasNext(InterruptibleIterator.scala:37)
at scala.collection.Iterator.foreach(Iterator.scala:941)
at scala.collection.Iterator.foreach$(Iterator.scala:941)
at org.apache.spark.InterruptibleIterator.foreach(InterruptibleIterator.scala:28)
at scala.collection.generic.Growable.$plus$plus$eq(Growable.scala:62)
at scala.collection.generic.Growable.$plus$plus$eq$(Growable.scala:53)
at scala.collection.mutable.ArrayBuffer.$plus$plus$eq(ArrayBuffer.scala:105)
at scala.collection.mutable.ArrayBuffer.$plus$plus$eq(ArrayBuffer.scala:49)
at scala.collection.TraversableOnce.to(TraversableOnce.scala:315)
at scala.collection.TraversableOnce.to$(TraversableOnce.scala:313)
at org.apache.spark.InterruptibleIterator.to(InterruptibleIterator.scala:28)
at scala.collection.TraversableOnce.toBuffer(TraversableOnce.scala:307)
at scala.collection.TraversableOnce.toBuffer$(TraversableOnce.scala:307)
at org.apache.spark.InterruptibleIterator.toBuffer(InterruptibleIterator.scala:28)
at scala.collection.TraversableOnce.toArray(TraversableOnce.scala:294)
at scala.collection.TraversableOnce.toArray$(TraversableOnce.scala:288)
at org.apache.spark.InterruptibleIterator.toArray(InterruptibleIterator.scala:28)
at org.apache.spark.rdd.RDD.$anonfun$collect$2(RDD.scala:1011)
at org.apache.spark.SparkContext.$anonfun$runJob$5(SparkContext.scala:2353)
at org.apache.spark.scheduler.ResultTask.runTask(ResultTask.scala:90)
at org.apache.spark.scheduler.Task.doRunTask(Task.scala:144)
at org.apache.spark.scheduler.Task.run(Task.scala:117)
at org.apache.spark.executor.Executor$TaskRunner.$anonfun$run$11(Executor.scala:657)
at org.apache.spark.util.Utils$.tryWithSafeFinally(Utils.scala:1581)
at org.apache.spark.executor.Executor$TaskRunner.run(Executor.scala:660)
at java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1149)
at java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:624)
at java.lang.Thread.run(Thread.java:748)
Driver stacktrace:
at org.apache.spark.scheduler.DAGScheduler.failJobAndIndependentStages(DAGScheduler.scala:2476)
at org.apache.spark.scheduler.DAGScheduler.$anonfun$abortStage$2(DAGScheduler.scala:2425)
at org.apache.spark.scheduler.DAGScheduler.$anonfun$abortStage$2$adapted(DAGScheduler.scala:2424)
at scala.collection.mutable.ResizableArray.foreach(ResizableArray.scala:62)
at scala.collection.mutable.ResizableArray.foreach$(ResizableArray.scala:55)
at scala.collection.mutable.ArrayBuffer.foreach(ArrayBuffer.scala:49)
at org.apache.spark.scheduler.DAGScheduler.abortStage(DAGScheduler.scala:2424)
at org.apache.spark.scheduler.DAGScheduler.$anonfun$handleTaskSetFailed$1(DAGScheduler.scala:1129)
at org.apache.spark.scheduler.DAGScheduler.$anonfun$handleTaskSetFailed$1$adapted(DAGScheduler.scala:1129)
at scala.Option.foreach(Option.scala:407)
at org.apache.spark.scheduler.DAGScheduler.handleTaskSetFailed(DAGScheduler.scala:1129)
at org.apache.spark.scheduler.DAGSchedulerEventProcessLoop.doOnReceive(DAGScheduler.scala:2676)
at org.apache.spark.scheduler.DAGSchedulerEventProcessLoop.onReceive(DAGScheduler.scala:2623)
at org.apache.spark.scheduler.DAGSchedulerEventProcessLoop.onReceive(DAGScheduler.scala:2611)
at org.apache.spark.util.EventLoop$$anon$1.run(EventLoop.scala:49)
at org.apache.spark.scheduler.DAGScheduler.runJob(DAGScheduler.scala:915)
at org.apache.spark.SparkContext.runJob(SparkContext.scala:2313)
at org.apache.spark.SparkContext.runJob(SparkContext.scala:2334)
at com.databricks.service.SparkServiceImpl$.$anonfun$executeRDD$3(SparkServiceImpl.scala:146)
at org.apache.spark.SparkContext.withCallSite(SparkContext.scala:2283)
at com.databricks.service.SparkServiceImpl$.$anonfun$executeRDD$1(SparkServiceImpl.scala:146)
at com.databricks.logging.UsageLogging.$anonfun$recordOperation$4(UsageLogging.scala:430)
at com.databricks.logging.UsageLogging.$anonfun$withAttributionContext$1(UsageLogging.scala:238)
at scala.util.DynamicVariable.withValue(DynamicVariable.scala:62)
at com.databricks.logging.UsageLogging.withAttributionContext(UsageLogging.scala:233)
at com.databricks.logging.UsageLogging.withAttributionContext$(UsageLogging.scala:230)
at com.databricks.spark.util.PublicDBLogging.withAttributionContext(DatabricksSparkUsageLogger.scala:18)
at com.databricks.logging.UsageLogging.withAttributionTags(UsageLogging.scala:275)
at com.databricks.logging.UsageLogging.withAttributionTags$(UsageLogging.scala:268)
at com.databricks.spark.util.PublicDBLogging.withAttributionTags(DatabricksSparkUsageLogger.scala:18)
at com.databricks.logging.UsageLogging.recordOperation(UsageLogging.scala:411)
at com.databricks.logging.UsageLogging.recordOperation$(UsageLogging.scala:337)
at com.databricks.spark.util.PublicDBLogging.recordOperation(DatabricksSparkUsageLogger.scala:18)
at com.databricks.spark.util.PublicDBLogging.recordOperation0(DatabricksSparkUsageLogger.scala:55)
at com.databricks.spark.util.DatabricksSparkUsageLogger.recordOperation(DatabricksSparkUsageLogger.scala:98)
at com.databricks.spark.util.UsageLogger.recordOperation(UsageLogger.scala:69)
at com.databricks.spark.util.UsageLogger.recordOperation$(UsageLogger.scala:56)
at com.databricks.spark.util.DatabricksSparkUsageLogger.recordOperation(DatabricksSparkUsageLogger.scala:67)
at com.databricks.spark.util.UsageLogging.recordOperation(UsageLogger.scala:344)
at com.databricks.spark.util.UsageLogging.recordOperation$(UsageLogger.scala:323)
at com.databricks.service.SparkServiceImpl$.recordOperation(SparkServiceImpl.scala:91)
at com.databricks.service.SparkServiceImpl$.executeRDD(SparkServiceImpl.scala:143)
at com.databricks.service.SparkServiceRPCHandler.execute0(SparkServiceRPCHandler.scala:663)
at com.databricks.service.SparkServiceRPCHandler.$anonfun$executeRPC0$1(SparkServiceRPCHandler.scala:439)
at scala.util.DynamicVariable.withValue(DynamicVariable.scala:62)
at com.databricks.service.SparkServiceRPCHandler.executeRPC0(SparkServiceRPCHandler.scala:346)
at com.databricks.service.SparkServiceRPCHandler$$anon$2.call(SparkServiceRPCHandler.scala:299)
at com.databricks.service.SparkServiceRPCHandler$$anon$2.call(SparkServiceRPCHandler.scala:285)
at java.util.concurrent.FutureTask.run(FutureTask.java:266)
at com.databricks.service.SparkServiceRPCHandler.$anonfun$executeRPC$1(SparkServiceRPCHandler.scala:333)
at scala.util.DynamicVariable.withValue(DynamicVariable.scala:62)
at com.databricks.service.SparkServiceRPCHandler.executeRPC(SparkServiceRPCHandler.scala:311)
at com.databricks.service.SparkServiceRPCServlet.doPost(SparkServiceRPCServer.scala:145)
at javax.servlet.http.HttpServlet.service(HttpServlet.java:707)
at javax.servlet.http.HttpServlet.service(HttpServlet.java:790)
at org.eclipse.jetty.servlet.ServletHolder.handle(ServletHolder.java:873)
at org.eclipse.jetty.servlet.ServletHandler.doHandle(ServletHandler.java:542)
at org.eclipse.jetty.server.handler.ScopedHandler.nextScope(ScopedHandler.java:205)
at org.eclipse.jetty.servlet.ServletHandler.doScope(ServletHandler.java:480)
at org.eclipse.jetty.server.handler.ScopedHandler.handle(ScopedHandler.java:144)
at org.eclipse.jetty.server.handler.HandlerWrapper.handle(HandlerWrapper.java:132)
at org.eclipse.jetty.server.Server.handle(Server.java:505)
at org.eclipse.jetty.server.HttpChannel.handle(HttpChannel.java:370)
at org.eclipse.jetty.server.HttpConnection.onFillable(HttpConnection.java:267)
at org.eclipse.jetty.io.AbstractConnection$ReadCallback.succeeded(AbstractConnection.java:305)
at org.eclipse.jetty.io.FillInterest.fillable(FillInterest.java:103)
at org.eclipse.jetty.io.ChannelEndPoint$2.run(ChannelEndPoint.java:117)
at org.eclipse.jetty.util.thread.strategy.EatWhatYouKill.runTask(EatWhatYouKill.java:333)
at org.eclipse.jetty.util.thread.strategy.EatWhatYouKill.doProduce(EatWhatYouKill.java:310)
at org.eclipse.jetty.util.thread.strategy.EatWhatYouKill.tryProduce(EatWhatYouKill.java:168)
at org.eclipse.jetty.util.thread.strategy.EatWhatYouKill.run(EatWhatYouKill.java:126)
at org.eclipse.jetty.util.thread.ReservedThreadExecutor$ReservedThread.run(ReservedThreadExecutor.java:366)
at org.eclipse.jetty.util.thread.QueuedThreadPool.runJob(QueuedThreadPool.java:698)
at org.eclipse.jetty.util.thread.QueuedThreadPool$Runner.run(QueuedThreadPool.java:804)
at java.lang.Thread.run(Thread.java:748)
Caused by: org.apache.spark.api.python.PythonException: Traceback (most recent call last):
File "/databricks/spark/python/pyspark/serializers.py", line 177, in _read_with_length
return self.loads(obj)
File "/databricks/spark/python/pyspark/serializers.py", line 466, in loads
return pickle.loads(obj, encoding=encoding)
TypeError: __init__() takes from 3 to 4 positional arguments but 5 were given
During handling of the above exception, another exception occurred:
Traceback (most recent call last):
File "/databricks/spark/python/pyspark/worker.py", line 642, in main
func, profiler, deserializer, serializer = read_command(pickleSer, infile)
File "/databricks/spark/python/pyspark/worker.py", line 74, in read_command
command = serializer._read_with_length(file)
File "/databricks/spark/python/pyspark/serializers.py", line 180, in _read_with_length
raise SerializationError("Caused by " + traceback.format_exc())
pyspark.serializers.SerializationError: Caused by Traceback (most recent call last):
File "/databricks/spark/python/pyspark/serializers.py", line 177, in _read_with_length
return self.loads(obj)
File "/databricks/spark/python/pyspark/serializers.py", line 466, in loads
return pickle.loads(obj, encoding=encoding)
TypeError: __init__() takes from 3 to 4 positional arguments but 5 were given
at org.apache.spark.api.python.BasePythonRunner$ReaderIterator.handlePythonException(PythonRunner.scala:585)
at org.apache.spark.api.python.PythonRunner$$anon$3.read(PythonRunner.scala:720)
at org.apache.spark.api.python.PythonRunner$$anon$3.read(PythonRunner.scala:703)
at org.apache.spark.api.python.BasePythonRunner$ReaderIterator.hasNext(PythonRunner.scala:538)
at org.apache.spark.InterruptibleIterator.hasNext(InterruptibleIterator.scala:37)
at scala.collection.Iterator.foreach(Iterator.scala:941)
at scala.collection.Iterator.foreach$(Iterator.scala:941)
at org.apache.spark.InterruptibleIterator.foreach(InterruptibleIterator.scala:28)
at scala.collection.generic.Growable.$plus$plus$eq(Growable.scala:62)
at scala.collection.generic.Growable.$plus$plus$eq$(Growable.scala:53)
at scala.collection.mutable.ArrayBuffer.$plus$plus$eq(ArrayBuffer.scala:105)
at scala.collection.mutable.ArrayBuffer.$plus$plus$eq(ArrayBuffer.scala:49)
at scala.collection.TraversableOnce.to(TraversableOnce.scala:315)
at scala.collection.TraversableOnce.to$(TraversableOnce.scala:313)
at org.apache.spark.InterruptibleIterator.to(InterruptibleIterator.scala:28)
at scala.collection.TraversableOnce.toBuffer(TraversableOnce.scala:307)
at scala.collection.TraversableOnce.toBuffer$(TraversableOnce.scala:307)
at org.apache.spark.InterruptibleIterator.toBuffer(InterruptibleIterator.scala:28)
at scala.collection.TraversableOnce.toArray(TraversableOnce.scala:294)
at scala.collection.TraversableOnce.toArray$(TraversableOnce.scala:288)
at org.apache.spark.InterruptibleIterator.toArray(InterruptibleIterator.scala:28)
at org.apache.spark.rdd.RDD.$anonfun$collect$2(RDD.scala:1011)
at org.apache.spark.SparkContext.$anonfun$runJob$5(SparkContext.scala:2353)
at org.apache.spark.scheduler.ResultTask.runTask(ResultTask.scala:90)
at org.apache.spark.scheduler.Task.doRunTask(Task.scala:144)
at org.apache.spark.scheduler.Task.run(Task.scala:117)
at org.apache.spark.executor.Executor$TaskRunner.$anonfun$run$11(Executor.scala:657)
at org.apache.spark.util.Utils$.tryWithSafeFinally(Utils.scala:1581)
at org.apache.spark.executor.Executor$TaskRunner.run(Executor.scala:660)
at java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1149)
at java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:624)
... 1 more
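The TypeError during pickle.loads looks as if the class being restored has a different __init__ signature on the cluster than on my client, which would point to a version mismatch between the databricks-connect environment and the cluster rather than a bug in joblib-spark itself. A quick comparison I'd run (my own sketch, assuming a live SparkSession from databricks-connect) to print the relevant versions on the driver and on an executor:

import sys

import joblib
import sklearn
from pyspark.sql import SparkSession

spark = SparkSession.builder.getOrCreate()
sc = spark.sparkContext

def executor_versions(_):
    # Runs on an executor; mirrors the driver-side imports.
    import sys, joblib, sklearn
    return {'python': sys.version.split()[0],
            'joblib': joblib.__version__,
            'sklearn': sklearn.__version__}

print('driver  :', {'python': sys.version.split()[0],
                    'joblib': joblib.__version__,
                    'sklearn': sklearn.__version__})
# Single-partition job so exactly one executor reports its versions.
print('executor:', sc.parallelize([0], 1).map(executor_versions).collect()[0])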