diff --git a/common/utils/src/main/resources/error/error-conditions.json b/common/utils/src/main/resources/error/error-conditions.json
index c91106b05d995..eb22bbd053aae 100644
--- a/common/utils/src/main/resources/error/error-conditions.json
+++ b/common/utils/src/main/resources/error/error-conditions.json
@@ -853,6 +853,11 @@
         "Please fit or load a model smaller than bytes."
       ]
     },
+    "MODEL_SUMMARY_LOST" : {
+      "message" : [
+        "The model summary is lost because the cached model is offloaded."
+      ]
+    },
     "UNSUPPORTED_EXCEPTION" : {
       "message" : [
         ""
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/FMClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/FMClassifier.scala
index cefa13b2bbe71..b653383161e74 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/FMClassifier.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/FMClassifier.scala
@@ -224,17 +224,8 @@ class FMClassifier @Since("3.0.0") (
       factors: Matrix,
       objectiveHistory: Array[Double]): FMClassificationModel = {
     val model = copyValues(new FMClassificationModel(uid, intercept, linear, factors))
-    val weightColName = if (!isDefined(weightCol)) "weightCol" else $(weightCol)
-
-    val (summaryModel, probabilityColName, predictionColName) = model.findSummaryModel()
-    val summary = new FMClassificationTrainingSummaryImpl(
-      summaryModel.transform(dataset),
-      probabilityColName,
-      predictionColName,
-      $(labelCol),
-      weightColName,
-      objectiveHistory)
-    model.setSummary(Some(summary))
+    model.createSummary(dataset, objectiveHistory)
+    model
   }
 
   @Since("3.0.0")
@@ -343,6 +334,42 @@ class FMClassificationModel private[classification] (
     s"uid=${super.toString}, numClasses=$numClasses, numFeatures=$numFeatures, " +
       s"factorSize=${$(factorSize)}, fitLinear=${$(fitLinear)}, fitIntercept=${$(fitIntercept)}"
   }
+
+  private[spark] def createSummary(
+    dataset: Dataset[_], objectiveHistory: Array[Double]
+  ): Unit = {
+    val weightColName = if (!isDefined(weightCol)) "weightCol" else $(weightCol)
+
+    val (summaryModel, probabilityColName, predictionColName) = findSummaryModel()
+    val summary = new FMClassificationTrainingSummaryImpl(
+      summaryModel.transform(dataset),
+      probabilityColName,
+      predictionColName,
+      $(labelCol),
+      weightColName,
+      objectiveHistory)
+    setSummary(Some(summary))
+  }
+
+  override private[spark] def saveSummary(path: String): Unit = {
+    ReadWriteUtils.saveObjectToLocal[Tuple1[Array[Double]]](
+      path, Tuple1(summary.objectiveHistory),
+      (data, dos) => {
+        ReadWriteUtils.serializeDoubleArray(data._1, dos)
+      }
+    )
+  }
+
+  override private[spark] def loadSummary(path: String, dataset: DataFrame): Unit = {
+    val Tuple1(objectiveHistory: Array[Double])
+      = ReadWriteUtils.loadObjectFromLocal[Tuple1[Array[Double]]](
+      path,
+      dis => {
+        Tuple1(ReadWriteUtils.deserializeDoubleArray(dis))
+      }
+    )
+    createSummary(dataset, objectiveHistory)
+  }
 }
 
 @Since("3.0.0")
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/LinearSVC.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/LinearSVC.scala
index a50346ae88f4c..0d163b761686d 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/LinearSVC.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/LinearSVC.scala
@@ -277,17 +277,8 @@ class LinearSVC @Since("2.2.0") (
       intercept: Double,
       objectiveHistory: Array[Double]): LinearSVCModel = {
     val model = copyValues(new LinearSVCModel(uid, coefficients, intercept))
-    val weightColName = if (!isDefined(weightCol)) "weightCol" else
$(weightCol) - - val (summaryModel, rawPredictionColName, predictionColName) = model.findSummaryModel() - val summary = new LinearSVCTrainingSummaryImpl( - summaryModel.transform(dataset), - rawPredictionColName, - predictionColName, - $(labelCol), - weightColName, - objectiveHistory) - model.setSummary(Some(summary)) + model.createSummary(dataset, objectiveHistory) + model } private def trainImpl( @@ -445,6 +436,42 @@ class LinearSVCModel private[classification] ( override def toString: String = { s"LinearSVCModel: uid=$uid, numClasses=$numClasses, numFeatures=$numFeatures" } + + private[spark] def createSummary( + dataset: Dataset[_], objectiveHistory: Array[Double] + ): Unit = { + val weightColName = if (!isDefined(weightCol)) "weightCol" else $(weightCol) + + val (summaryModel, rawPredictionColName, predictionColName) = findSummaryModel() + val summary = new LinearSVCTrainingSummaryImpl( + summaryModel.transform(dataset), + rawPredictionColName, + predictionColName, + $(labelCol), + weightColName, + objectiveHistory) + setSummary(Some(summary)) + } + + override private[spark] def saveSummary(path: String): Unit = { + ReadWriteUtils.saveObjectToLocal[Tuple1[Array[Double]]]( + path, Tuple1(summary.objectiveHistory), + (data, dos) => { + ReadWriteUtils.serializeDoubleArray(data._1, dos) + } + ) + } + + override private[spark] def loadSummary(path: String, dataset: DataFrame): Unit = { + val Tuple1(objectiveHistory: Array[Double]) + = ReadWriteUtils.loadObjectFromLocal[Tuple1[Array[Double]]]( + path, + dis => { + Tuple1(ReadWriteUtils.deserializeDoubleArray(dis)) + } + ) + createSummary(dataset, objectiveHistory) + } } @Since("2.2.0") diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala index 58a2652d0eab9..8c010f67f5e0a 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala @@ -718,29 +718,8 @@ class LogisticRegression @Since("1.2.0") ( objectiveHistory: Array[Double]): LogisticRegressionModel = { val model = copyValues(new LogisticRegressionModel(uid, coefficientMatrix, interceptVector, numClasses, checkMultinomial(numClasses))) - val weightColName = if (!isDefined(weightCol)) "weightCol" else $(weightCol) - - val (summaryModel, probabilityColName, predictionColName) = model.findSummaryModel() - val logRegSummary = if (numClasses <= 2) { - new BinaryLogisticRegressionTrainingSummaryImpl( - summaryModel.transform(dataset), - probabilityColName, - predictionColName, - $(labelCol), - $(featuresCol), - weightColName, - objectiveHistory) - } else { - new LogisticRegressionTrainingSummaryImpl( - summaryModel.transform(dataset), - probabilityColName, - predictionColName, - $(labelCol), - $(featuresCol), - weightColName, - objectiveHistory) - } - model.setSummary(Some(logRegSummary)) + model.createSummary(dataset, objectiveHistory) + model } private def createBounds( @@ -1323,6 +1302,54 @@ class LogisticRegressionModel private[spark] ( override def toString: String = { s"LogisticRegressionModel: uid=$uid, numClasses=$numClasses, numFeatures=$numFeatures" } + + private[spark] def createSummary( + dataset: Dataset[_], objectiveHistory: Array[Double] + ): Unit = { + val weightColName = if (!isDefined(weightCol)) "weightCol" else $(weightCol) + + val (summaryModel, probabilityColName, predictionColName) = findSummaryModel() + val logRegSummary = if 
(numClasses <= 2) { + new BinaryLogisticRegressionTrainingSummaryImpl( + summaryModel.transform(dataset), + probabilityColName, + predictionColName, + $(labelCol), + $(featuresCol), + weightColName, + objectiveHistory) + } else { + new LogisticRegressionTrainingSummaryImpl( + summaryModel.transform(dataset), + probabilityColName, + predictionColName, + $(labelCol), + $(featuresCol), + weightColName, + objectiveHistory) + } + setSummary(Some(logRegSummary)) + } + + override private[spark] def saveSummary(path: String): Unit = { + ReadWriteUtils.saveObjectToLocal[Tuple1[Array[Double]]]( + path, Tuple1(summary.objectiveHistory), + (data, dos) => { + ReadWriteUtils.serializeDoubleArray(data._1, dos) + } + ) + } + + override private[spark] def loadSummary(path: String, dataset: DataFrame): Unit = { + val Tuple1(objectiveHistory: Array[Double]) + = ReadWriteUtils.loadObjectFromLocal[Tuple1[Array[Double]]]( + path, + dis => { + Tuple1(ReadWriteUtils.deserializeDoubleArray(dis)) + } + ) + createSummary(dataset, objectiveHistory) + } } @Since("1.6.0") diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifier.scala index 6bd46cff815d7..5e52d62fb83cb 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifier.scala @@ -251,14 +251,8 @@ class MultilayerPerceptronClassifier @Since("1.5.0") ( objectiveHistory: Array[Double]): MultilayerPerceptronClassificationModel = { val model = copyValues(new MultilayerPerceptronClassificationModel(uid, weights)) - val (summaryModel, _, predictionColName) = model.findSummaryModel() - val summary = new MultilayerPerceptronClassificationTrainingSummaryImpl( - summaryModel.transform(dataset), - predictionColName, - $(labelCol), - "", - objectiveHistory) - model.setSummary(Some(summary)) + model.createSummary(dataset, objectiveHistory) + model } } @@ -365,6 +359,39 @@ class MultilayerPerceptronClassificationModel private[ml] ( s"MultilayerPerceptronClassificationModel: uid=$uid, numLayers=${$(layers).length}, " + s"numClasses=$numClasses, numFeatures=$numFeatures" } + + private[spark] def createSummary( + dataset: Dataset[_], objectiveHistory: Array[Double] + ): Unit = { + val (summaryModel, _, predictionColName) = findSummaryModel() + val summary = new MultilayerPerceptronClassificationTrainingSummaryImpl( + summaryModel.transform(dataset), + predictionColName, + $(labelCol), + "", + objectiveHistory) + setSummary(Some(summary)) + } + + override private[spark] def saveSummary(path: String): Unit = { + ReadWriteUtils.saveObjectToLocal[Tuple1[Array[Double]]]( + path, Tuple1(summary.objectiveHistory), + (data, dos) => { + ReadWriteUtils.serializeDoubleArray(data._1, dos) + } + ) + } + + override private[spark] def loadSummary(path: String, dataset: DataFrame): Unit = { + val Tuple1(objectiveHistory: Array[Double]) + = ReadWriteUtils.loadObjectFromLocal[Tuple1[Array[Double]]]( + path, + dis => { + Tuple1(ReadWriteUtils.deserializeDoubleArray(dis)) + } + ) + createSummary(dataset, objectiveHistory) + } } @Since("2.0.0") diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala index f64e2a6d4efc3..8b580b1e075c5 100644 --- 
a/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala
@@ -182,26 +182,8 @@ class RandomForestClassifier @Since("1.4.0") (
       numFeatures: Int,
       numClasses: Int): RandomForestClassificationModel = {
     val model = copyValues(new RandomForestClassificationModel(uid, trees, numFeatures, numClasses))
-    val weightColName = if (!isDefined(weightCol)) "weightCol" else $(weightCol)
-
-    val (summaryModel, probabilityColName, predictionColName) = model.findSummaryModel()
-    val rfSummary = if (numClasses <= 2) {
-      new BinaryRandomForestClassificationTrainingSummaryImpl(
-        summaryModel.transform(dataset),
-        probabilityColName,
-        predictionColName,
-        $(labelCol),
-        weightColName,
-        Array(0.0))
-    } else {
-      new RandomForestClassificationTrainingSummaryImpl(
-        summaryModel.transform(dataset),
-        predictionColName,
-        $(labelCol),
-        weightColName,
-        Array(0.0))
-    }
-    model.setSummary(Some(rfSummary))
+    model.createSummary(dataset)
+    model
   }
 
   @Since("1.4.1")
@@ -393,6 +375,35 @@ class RandomForestClassificationModel private[ml] (
   @Since("2.0.0")
   override def write: MLWriter =
     new RandomForestClassificationModel.RandomForestClassificationModelWriter(this)
+
+  private[spark] def createSummary(dataset: Dataset[_]): Unit = {
+    val weightColName = if (!isDefined(weightCol)) "weightCol" else $(weightCol)
+
+    val (summaryModel, probabilityColName, predictionColName) = findSummaryModel()
+    val rfSummary = if (numClasses <= 2) {
+      new BinaryRandomForestClassificationTrainingSummaryImpl(
+        summaryModel.transform(dataset),
+        probabilityColName,
+        predictionColName,
+        $(labelCol),
+        weightColName,
+        Array(0.0))
+    } else {
+      new RandomForestClassificationTrainingSummaryImpl(
+        summaryModel.transform(dataset),
+        predictionColName,
+        $(labelCol),
+        weightColName,
+        Array(0.0))
+    }
+    setSummary(Some(rfSummary))
+  }
+
+  override private[spark] def saveSummary(path: String): Unit = {}
+
+  override private[spark] def loadSummary(path: String, dataset: DataFrame): Unit = {
+    createSummary(dataset)
+  }
 }
 
 @Since("2.0.0")
diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/BisectingKMeans.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/BisectingKMeans.scala
index 3248b4b391d0a..9e09ee00c3e30 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/clustering/BisectingKMeans.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/BisectingKMeans.scala
@@ -180,6 +180,9 @@ class BisectingKMeansModel private[ml] (
   override def summary: BisectingKMeansSummary = super.summary
 
   override def estimatedSize: Long = SizeEstimator.estimate(parentModel)
+
+  // The BisectingKMeans model doesn't support offloading yet, so put an empty `saveSummary` here for now
+  override private[spark] def saveSummary(path: String): Unit = {}
 }
 
 object BisectingKMeansModel extends MLReadable[BisectingKMeansModel] {
diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala
index a94b8a87d8fc7..e7f930065486b 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala
@@ -223,6 +223,36 @@ class GaussianMixtureModel private[ml] (
   override def summary: GaussianMixtureSummary = super.summary
 
   override def estimatedSize: Long = SizeEstimator.estimate((weights, gaussians))
+
+  private[spark] def createSummary(
+    predictions:
DataFrame, logLikelihood: Double, iteration: Int + ): Unit = { + val summary = new GaussianMixtureSummary(predictions, + $(predictionCol), $(probabilityCol), $(featuresCol), $(k), logLikelihood, iteration) + setSummary(Some(summary)) + } + + override private[spark] def saveSummary(path: String): Unit = { + ReadWriteUtils.saveObjectToLocal[(Double, Int)]( + path, (summary.logLikelihood, summary.numIter), + (data, dos) => { + dos.writeDouble(data._1) + dos.writeInt(data._2) + } + ) + } + + override private[spark] def loadSummary(path: String, dataset: DataFrame): Unit = { + val (logLikelihood: Double, numIter: Int) = ReadWriteUtils.loadObjectFromLocal[(Double, Int)]( + path, + dis => { + val logLikelihood = dis.readDouble() + val numIter = dis.readInt() + (logLikelihood, numIter) + } + ) + createSummary(dataset, logLikelihood, numIter) + } } @Since("2.0.0") @@ -453,11 +483,10 @@ class GaussianMixture @Since("2.0.0") ( val model = copyValues(new GaussianMixtureModel(uid, weights, gaussianDists)) .setParent(this) - val summary = new GaussianMixtureSummary(model.transform(dataset), - $(predictionCol), $(probabilityCol), $(featuresCol), $(k), logLikelihood, iteration) + model.createSummary(model.transform(dataset), logLikelihood, iteration) instr.logNamedValue("logLikelihood", logLikelihood) - instr.logNamedValue("clusterSizes", summary.clusterSizes) - model.setSummary(Some(summary)) + instr.logNamedValue("clusterSizes", model.summary.clusterSizes) + model } private def trainImpl( diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala index f3ac58e670e5a..ccae39cedd20f 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala @@ -215,6 +215,42 @@ class KMeansModel private[ml] ( override def summary: KMeansSummary = super.summary override def estimatedSize: Long = SizeEstimator.estimate(parentModel.clusterCenters) + + private[spark] def createSummary( + predictions: DataFrame, numIter: Int, trainingCost: Double + ): Unit = { + val summary = new KMeansSummary( + predictions, + $(predictionCol), + $(featuresCol), + $(k), + numIter, + trainingCost) + + setSummary(Some(summary)) + } + + override private[spark] def saveSummary(path: String): Unit = { + ReadWriteUtils.saveObjectToLocal[(Int, Double)]( + path, (summary.numIter, summary.trainingCost), + (data, dos) => { + dos.writeInt(data._1) + dos.writeDouble(data._2) + } + ) + } + + override private[spark] def loadSummary(path: String, dataset: DataFrame): Unit = { + val (numIter: Int, trainingCost: Double) = ReadWriteUtils.loadObjectFromLocal[(Int, Double)]( + path, + dis => { + val numIter = dis.readInt() + val trainingCost = dis.readDouble() + (numIter, trainingCost) + } + ) + createSummary(dataset, numIter, trainingCost) + } } /** Helper class for storing model data */ @@ -414,16 +450,9 @@ class KMeans @Since("1.5.0") ( } val model = copyValues(new KMeansModel(uid, oldModel).setParent(this)) - val summary = new KMeansSummary( - model.transform(dataset), - $(predictionCol), - $(featuresCol), - $(k), - oldModel.numIter, - oldModel.trainingCost) - model.setSummary(Some(summary)) - instr.logNamedValue("clusterSizes", summary.clusterSizes) + model.createSummary(model.transform(dataset), oldModel.numIter, oldModel.trainingCost) + instr.logNamedValue("clusterSizes", model.summary.clusterSizes) model } diff --git 
a/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala index 14467c761b216..cf62c2bf41b6d 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala @@ -419,9 +419,8 @@ class GeneralizedLinearRegression @Since("2.0.0") (@Since("2.0.0") override val val model = copyValues( new GeneralizedLinearRegressionModel(uid, wlsModel.coefficients, wlsModel.intercept) .setParent(this)) - val trainingSummary = new GeneralizedLinearRegressionTrainingSummary(dataset, model, - wlsModel.diagInvAtWA.toArray, 1, getSolver) - model.setSummary(Some(trainingSummary)) + model.createSummary(dataset, wlsModel.diagInvAtWA.toArray, 1) + model } else { val instances = validated.rdd.map { case Row(label: Double, weight: Double, offset: Double, features: Vector) => @@ -436,9 +435,8 @@ class GeneralizedLinearRegression @Since("2.0.0") (@Since("2.0.0") override val val model = copyValues( new GeneralizedLinearRegressionModel(uid, irlsModel.coefficients, irlsModel.intercept) .setParent(this)) - val trainingSummary = new GeneralizedLinearRegressionTrainingSummary(dataset, model, - irlsModel.diagInvAtWA.toArray, irlsModel.numIterations, getSolver) - model.setSummary(Some(trainingSummary)) + model.createSummary(dataset, irlsModel.diagInvAtWA.toArray, irlsModel.numIterations) + model } model @@ -1140,6 +1138,39 @@ class GeneralizedLinearRegressionModel private[ml] ( s"GeneralizedLinearRegressionModel: uid=$uid, family=${$(family)}, link=${$(link)}, " + s"numFeatures=$numFeatures" } + + private[spark] def createSummary( + dataset: Dataset[_], diagInvAtWA: Array[Double], numIter: Int + ): Unit = { + val summary = new GeneralizedLinearRegressionTrainingSummary( + dataset, this, diagInvAtWA, numIter, $(solver) + ) + + setSummary(Some(summary)) + } + + override private[spark] def saveSummary(path: String): Unit = { + ReadWriteUtils.saveObjectToLocal[(Array[Double], Int)]( + path, (summary.diagInvAtWA, summary.numIterations), + (data, dos) => { + ReadWriteUtils.serializeDoubleArray(data._1, dos) + dos.writeInt(data._2) + } + ) + } + + override private[spark] def loadSummary(path: String, dataset: DataFrame): Unit = { + val (diagInvAtWA: Array[Double], numIterations: Int) = + ReadWriteUtils.loadObjectFromLocal[(Array[Double], Int)]( + path, + dis => { + val diagInvAtWA = ReadWriteUtils.deserializeDoubleArray(dis) + val numIterations = dis.readInt() + (diagInvAtWA, numIterations) + } + ) + createSummary(dataset, diagInvAtWA, numIterations) + } } @Since("2.0.0") @@ -1467,7 +1498,7 @@ class GeneralizedLinearRegressionSummary private[regression] ( class GeneralizedLinearRegressionTrainingSummary private[regression] ( dataset: Dataset[_], origModel: GeneralizedLinearRegressionModel, - private val diagInvAtWA: Array[Double], + private[spark] val diagInvAtWA: Array[Double], @Since("2.0.0") val numIterations: Int, @Since("2.0.0") val solver: String) extends GeneralizedLinearRegressionSummary(dataset, origModel) with Serializable { diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala index b06140e48338c..822df270c0bf7 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala 
@@ -433,15 +433,8 @@ class LinearRegression @Since("1.3.0") (@Since("1.3.0") override val uid: String } val model = createModel(parameters, yMean, yStd, featuresMean, featuresStd) - - // Handle possible missing or invalid prediction columns - val (summaryModel, predictionColName) = model.findSummaryModelAndPredictionCol() - val trainingSummary = new LinearRegressionTrainingSummary( - summaryModel.transform(dataset), predictionColName, $(labelCol), $(featuresCol), - summaryModel.get(summaryModel.weightCol).getOrElse(""), - summaryModel.numFeatures, summaryModel.getFitIntercept, - Array(0.0), objectiveHistory) - model.setSummary(Some(trainingSummary)) + model.createSummary(dataset, Array(0.0), objectiveHistory, Array.emptyDoubleArray) + model } private def trainWithNormal( @@ -459,20 +452,16 @@ class LinearRegression @Since("1.3.0") (@Since("1.3.0") override val uid: String // attach returned model. val lrModel = copyValues(new LinearRegressionModel( uid, model.coefficients.compressed, model.intercept)) - val (summaryModel, predictionColName) = lrModel.findSummaryModelAndPredictionCol() - val coefficientArray = if (summaryModel.getFitIntercept) { - summaryModel.coefficients.toArray ++ Array(summaryModel.intercept) + val coefficientArray = if (lrModel.getFitIntercept) { + lrModel.coefficients.toArray ++ Array(lrModel.intercept) } else { - summaryModel.coefficients.toArray + lrModel.coefficients.toArray } - val trainingSummary = new LinearRegressionTrainingSummary( - summaryModel.transform(dataset), predictionColName, $(labelCol), $(featuresCol), - summaryModel.get(summaryModel.weightCol).getOrElse(""), - summaryModel.numFeatures, summaryModel.getFitIntercept, - model.diagInvAtWA.toArray, model.objectiveHistory, coefficientArray) - - lrModel.setSummary(Some(trainingSummary)) + lrModel.createSummary( + dataset, model.diagInvAtWA.toArray, model.objectiveHistory, coefficientArray + ) + lrModel } private def trainWithConstantLabel( @@ -497,16 +486,9 @@ class LinearRegression @Since("1.3.0") (@Since("1.3.0") override val uid: String val intercept = yMean val model = copyValues(new LinearRegressionModel(uid, coefficients, intercept)) - // Handle possible missing or invalid prediction columns - val (summaryModel, predictionColName) = model.findSummaryModelAndPredictionCol() - val trainingSummary = new LinearRegressionTrainingSummary( - summaryModel.transform(dataset), predictionColName, $(labelCol), $(featuresCol), - summaryModel.get(summaryModel.weightCol).getOrElse(""), - summaryModel.numFeatures, summaryModel.getFitIntercept, - Array(0.0), Array(0.0)) - - model.setSummary(Some(trainingSummary)) + model.createSummary(dataset, Array(0.0), Array(0.0), Array.emptyDoubleArray) + model } private def createOptimizer( @@ -800,6 +782,53 @@ class LinearRegressionModel private[ml] ( override def toString: String = { s"LinearRegressionModel: uid=$uid, numFeatures=$numFeatures" } + + private[spark] def createSummary( + dataset: Dataset[_], + diagInvAtWA: Array[Double], + objectiveHistory: Array[Double], + coefficientArray: Array[Double] + ): Unit = { + // Handle possible missing or invalid prediction columns + val (summaryModel, predictionColName) = findSummaryModelAndPredictionCol() + + val trainingSummary = new LinearRegressionTrainingSummary( + summaryModel.transform(dataset), predictionColName, $(labelCol), $(featuresCol), + summaryModel.get(summaryModel.weightCol).getOrElse(""), + summaryModel.numFeatures, summaryModel.getFitIntercept, + diagInvAtWA, objectiveHistory, coefficientArray) + + 
setSummary(Some(trainingSummary)) + } + + override private[spark] def saveSummary(path: String): Unit = { + ReadWriteUtils.saveObjectToLocal[(Array[Double], Array[Double], Array[Double])]( + path, (summary.diagInvAtWA, summary.objectiveHistory, summary.coefficientArray), + (data, dos) => { + ReadWriteUtils.serializeDoubleArray(data._1, dos) + ReadWriteUtils.serializeDoubleArray(data._2, dos) + ReadWriteUtils.serializeDoubleArray(data._3, dos) + } + ) + } + + override private[spark] def loadSummary(path: String, dataset: DataFrame): Unit = { + val ( + diagInvAtWA: Array[Double], + objectiveHistory: Array[Double], + coefficientArray: Array[Double] + ) + = ReadWriteUtils.loadObjectFromLocal[(Array[Double], Array[Double], Array[Double])]( + path, + dis => { + val diagInvAtWA = ReadWriteUtils.deserializeDoubleArray(dis) + val objectiveHistory = ReadWriteUtils.deserializeDoubleArray(dis) + val coefficientArray = ReadWriteUtils.deserializeDoubleArray(dis) + (diagInvAtWA, objectiveHistory, coefficientArray) + } + ) + createSummary(dataset, diagInvAtWA, objectiveHistory, coefficientArray) + } } private[ml] case class LinearModelData(intercept: Double, coefficients: Vector, scale: Double) @@ -926,7 +955,7 @@ class LinearRegressionTrainingSummary private[regression] ( private val fitIntercept: Boolean, diagInvAtWA: Array[Double], val objectiveHistory: Array[Double], - private val coefficientArray: Array[Double] = Array.emptyDoubleArray) + override private[regression] val coefficientArray: Array[Double] = Array.emptyDoubleArray) extends LinearRegressionSummary( predictions, predictionCol, @@ -972,8 +1001,8 @@ class LinearRegressionSummary private[regression] ( private val weightCol: String, private val numFeatures: Int, private val fitIntercept: Boolean, - private val diagInvAtWA: Array[Double], - private val coefficientArray: Array[Double] = Array.emptyDoubleArray) + private[regression] val diagInvAtWA: Array[Double], + private[regression] val coefficientArray: Array[Double] = Array.emptyDoubleArray) extends Summary with Serializable { @transient private val metrics = { diff --git a/mllib/src/main/scala/org/apache/spark/ml/util/HasTrainingSummary.scala b/mllib/src/main/scala/org/apache/spark/ml/util/HasTrainingSummary.scala index 0ba8ce072ab4a..c6f6babf71a2b 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/util/HasTrainingSummary.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/util/HasTrainingSummary.scala @@ -19,6 +19,7 @@ package org.apache.spark.ml.util import org.apache.spark.SparkException import org.apache.spark.annotation.Since +import org.apache.spark.sql.DataFrame /** @@ -49,4 +50,14 @@ private[spark] trait HasTrainingSummary[T] { this.trainingSummary = summary this } + + private[spark] def loadSummary(path: String, dataset: DataFrame): Unit = { + throw new SparkException( + s"No loadSummary implementation for this ${this.getClass.getSimpleName}") + } + + private[spark] def saveSummary(path: String): Unit = { + throw new SparkException( + s"No saveSummary implementation for this ${this.getClass.getSimpleName}") + } } diff --git a/python/pyspark/ml/classification.py b/python/pyspark/ml/classification.py index a5fdaed0db2c4..f66fc762971b5 100644 --- a/python/pyspark/ml/classification.py +++ b/python/pyspark/ml/classification.py @@ -889,15 +889,14 @@ def summary(self) -> "LinearSVCTrainingSummary": # type: ignore[override] Gets summary (accuracy/precision/recall, objective history, total iterations) of model trained on the training set. 
An exception is thrown if `trainingSummary is None`. """ - if self.hasSummary: - s = LinearSVCTrainingSummary(super(LinearSVCModel, self).summary) - if is_remote(): - s.__source_transformer__ = self # type: ignore[attr-defined] - return s - else: - raise RuntimeError( - "No training summary available for this %s" % self.__class__.__name__ - ) + return super().summary + + @property + def _summaryCls(self) -> type: + return LinearSVCTrainingSummary + + def _summary_dataset(self, train_dataset: DataFrame) -> DataFrame: + return train_dataset def evaluate(self, dataset: DataFrame) -> "LinearSVCSummary": """ @@ -1577,29 +1576,6 @@ def interceptVector(self) -> Vector: """ return self._call_java("interceptVector") - @property - @since("2.0.0") - def summary(self) -> "LogisticRegressionTrainingSummary": - """ - Gets summary (accuracy/precision/recall, objective history, total iterations) of model - trained on the training set. An exception is thrown if `trainingSummary is None`. - """ - if self.hasSummary: - s: LogisticRegressionTrainingSummary - if self.numClasses <= 2: - s = BinaryLogisticRegressionTrainingSummary( - super(LogisticRegressionModel, self).summary - ) - else: - s = LogisticRegressionTrainingSummary(super(LogisticRegressionModel, self).summary) - if is_remote(): - s.__source_transformer__ = self # type: ignore[attr-defined] - return s - else: - raise RuntimeError( - "No training summary available for this %s" % self.__class__.__name__ - ) - def evaluate(self, dataset: DataFrame) -> "LogisticRegressionSummary": """ Evaluates the model on a test dataset. @@ -1623,6 +1599,15 @@ def evaluate(self, dataset: DataFrame) -> "LogisticRegressionSummary": s.__source_transformer__ = self # type: ignore[attr-defined] return s + @property + def _summaryCls(self) -> type: + if self.numClasses <= 2: + return BinaryLogisticRegressionTrainingSummary + return LogisticRegressionTrainingSummary + + def _summary_dataset(self, train_dataset: DataFrame) -> DataFrame: + return train_dataset + class LogisticRegressionSummary(_ClassificationSummary): """ @@ -2315,29 +2300,13 @@ def trees(self) -> List[DecisionTreeClassificationModel]: return [DecisionTreeClassificationModel(m) for m in list(self._call_java("trees"))] @property - @since("3.1.0") - def summary(self) -> "RandomForestClassificationTrainingSummary": - """ - Gets summary (accuracy/precision/recall, objective history, total iterations) of model - trained on the training set. An exception is thrown if `trainingSummary is None`. 
- """ - if self.hasSummary: - s: RandomForestClassificationTrainingSummary - if self.numClasses <= 2: - s = BinaryRandomForestClassificationTrainingSummary( - super(RandomForestClassificationModel, self).summary - ) - else: - s = RandomForestClassificationTrainingSummary( - super(RandomForestClassificationModel, self).summary - ) - if is_remote(): - s.__source_transformer__ = self # type: ignore[attr-defined] - return s - else: - raise RuntimeError( - "No training summary available for this %s" % self.__class__.__name__ - ) + def _summaryCls(self) -> type: + if self.numClasses <= 2: + return BinaryRandomForestClassificationTrainingSummary + return RandomForestClassificationTrainingSummary + + def _summary_dataset(self, train_dataset: DataFrame) -> DataFrame: + return train_dataset def evaluate(self, dataset: DataFrame) -> "RandomForestClassificationSummary": """ @@ -3372,17 +3341,14 @@ def summary( # type: ignore[override] Gets summary (accuracy/precision/recall, objective history, total iterations) of model trained on the training set. An exception is thrown if `trainingSummary is None`. """ - if self.hasSummary: - s = MultilayerPerceptronClassificationTrainingSummary( - super(MultilayerPerceptronClassificationModel, self).summary - ) - if is_remote(): - s.__source_transformer__ = self # type: ignore[attr-defined] - return s - else: - raise RuntimeError( - "No training summary available for this %s" % self.__class__.__name__ - ) + return super().summary + + @property + def _summaryCls(self) -> type: + return MultilayerPerceptronClassificationTrainingSummary + + def _summary_dataset(self, train_dataset: DataFrame) -> DataFrame: + return train_dataset def evaluate(self, dataset: DataFrame) -> "MultilayerPerceptronClassificationSummary": """ @@ -4321,22 +4287,6 @@ def factors(self) -> Matrix: """ return self._call_java("factors") - @since("3.1.0") - def summary(self) -> "FMClassificationTrainingSummary": - """ - Gets summary (accuracy/precision/recall, objective history, total iterations) of model - trained on the training set. An exception is thrown if `trainingSummary is None`. - """ - if self.hasSummary: - s = FMClassificationTrainingSummary(super(FMClassificationModel, self).summary) - if is_remote(): - s.__source_transformer__ = self # type: ignore[attr-defined] - return s - else: - raise RuntimeError( - "No training summary available for this %s" % self.__class__.__name__ - ) - def evaluate(self, dataset: DataFrame) -> "FMClassificationSummary": """ Evaluates the model on a test dataset. @@ -4356,6 +4306,21 @@ def evaluate(self, dataset: DataFrame) -> "FMClassificationSummary": s.__source_transformer__ = self # type: ignore[attr-defined] return s + @since("3.1.0") + def summary(self) -> "FMClassificationTrainingSummary": + """ + Gets summary (accuracy/precision/recall, objective history, total iterations) of model + trained on the training set. An exception is thrown if `trainingSummary is None`. 
+ """ + return super().summary + + @property + def _summaryCls(self) -> type: + return FMClassificationTrainingSummary + + def _summary_dataset(self, train_dataset: DataFrame) -> DataFrame: + return train_dataset + class FMClassificationSummary(_BinaryClassificationSummary): """ diff --git a/python/pyspark/ml/clustering.py b/python/pyspark/ml/clustering.py index 7267ee2805987..0e26398de3c45 100644 --- a/python/pyspark/ml/clustering.py +++ b/python/pyspark/ml/clustering.py @@ -255,23 +255,6 @@ def gaussiansDF(self) -> DataFrame: """ return self._call_java("gaussiansDF") - @property - @since("2.1.0") - def summary(self) -> "GaussianMixtureSummary": - """ - Gets summary (cluster assignments, cluster sizes) of the model trained on the - training set. An exception is thrown if no summary exists. - """ - if self.hasSummary: - s = GaussianMixtureSummary(super(GaussianMixtureModel, self).summary) - if is_remote(): - s.__source_transformer__ = self # type: ignore[attr-defined] - return s - else: - raise RuntimeError( - "No training summary available for this %s" % self.__class__.__name__ - ) - @since("3.0.0") def predict(self, value: Vector) -> int: """ @@ -286,6 +269,10 @@ def predictProbability(self, value: Vector) -> Vector: """ return self._call_java("predictProbability", value) + @property + def _summaryCls(self) -> type: + return GaussianMixtureSummary + @inherit_doc class GaussianMixture( @@ -705,23 +692,6 @@ def numFeatures(self) -> int: """ return self._call_java("numFeatures") - @property - @since("2.1.0") - def summary(self) -> KMeansSummary: - """ - Gets summary (cluster assignments, cluster sizes) of the model trained on the - training set. An exception is thrown if no summary exists. - """ - if self.hasSummary: - s = KMeansSummary(super(KMeansModel, self).summary) - if is_remote(): - s.__source_transformer__ = self # type: ignore[attr-defined] - return s - else: - raise RuntimeError( - "No training summary available for this %s" % self.__class__.__name__ - ) - @since("3.0.0") def predict(self, value: Vector) -> int: """ @@ -729,6 +699,10 @@ def predict(self, value: Vector) -> int: """ return self._call_java("predict", value) + @property + def _summaryCls(self) -> type: + return KMeansSummary + @inherit_doc class KMeans(JavaEstimator[KMeansModel], _KMeansParams, JavaMLWritable, JavaMLReadable["KMeans"]): @@ -1055,23 +1029,6 @@ def numFeatures(self) -> int: """ return self._call_java("numFeatures") - @property - @since("2.1.0") - def summary(self) -> "BisectingKMeansSummary": - """ - Gets summary (cluster assignments, cluster sizes) of the model trained on the - training set. An exception is thrown if no summary exists. - """ - if self.hasSummary: - s = BisectingKMeansSummary(super(BisectingKMeansModel, self).summary) - if is_remote(): - s.__source_transformer__ = self # type: ignore[attr-defined] - return s - else: - raise RuntimeError( - "No training summary available for this %s" % self.__class__.__name__ - ) - @since("3.0.0") def predict(self, value: Vector) -> int: """ @@ -1079,6 +1036,10 @@ def predict(self, value: Vector) -> int: """ return self._call_java("predict", value) + @property + def _summaryCls(self) -> type: + return BisectingKMeansSummary + @inherit_doc class BisectingKMeans( diff --git a/python/pyspark/ml/connect/proto.py b/python/pyspark/ml/connect/proto.py index 31f100859281a..7cffd32631ba5 100644 --- a/python/pyspark/ml/connect/proto.py +++ b/python/pyspark/ml/connect/proto.py @@ -70,8 +70,13 @@ class AttributeRelation(LogicalPlan): could be a model or a summary. 
This attribute returns a DataFrame. """ - def __init__(self, ref_id: str, methods: List[pb2.Fetch.Method]) -> None: - super().__init__(None) + def __init__( + self, + ref_id: str, + methods: List[pb2.Fetch.Method], + child: Optional["LogicalPlan"] = None, + ) -> None: + super().__init__(child) self._ref_id = ref_id self._methods = methods @@ -79,4 +84,6 @@ def plan(self, session: "SparkConnectClient") -> pb2.Relation: plan = self._create_proto_relation() plan.ml_relation.fetch.obj_ref.CopyFrom(pb2.ObjectRef(id=self._ref_id)) plan.ml_relation.fetch.methods.extend(self._methods) + if self._child is not None: + plan.ml_relation.model_summary_dataset.CopyFrom(self._child.plan(session)) return plan diff --git a/python/pyspark/ml/regression.py b/python/pyspark/ml/regression.py index 66d6dbd6a2678..ce97b98f6665c 100644 --- a/python/pyspark/ml/regression.py +++ b/python/pyspark/ml/regression.py @@ -479,22 +479,11 @@ def scale(self) -> float: return self._call_java("scale") @property - @since("2.0.0") - def summary(self) -> "LinearRegressionTrainingSummary": - """ - Gets summary (residuals, MSE, r-squared ) of model on - training set. An exception is thrown if - `trainingSummary is None`. - """ - if self.hasSummary: - s = LinearRegressionTrainingSummary(super(LinearRegressionModel, self).summary) - if is_remote(): - s.__source_transformer__ = self # type: ignore[attr-defined] - return s - else: - raise RuntimeError( - "No training summary available for this %s" % self.__class__.__name__ - ) + def _summaryCls(self) -> type: + return LinearRegressionTrainingSummary + + def _summary_dataset(self, train_dataset: DataFrame) -> DataFrame: + return train_dataset def evaluate(self, dataset: DataFrame) -> "LinearRegressionSummary": """ @@ -2774,24 +2763,11 @@ def intercept(self) -> float: return self._call_java("intercept") @property - @since("2.0.0") - def summary(self) -> "GeneralizedLinearRegressionTrainingSummary": - """ - Gets summary (residuals, deviance, p-values) of model on - training set. An exception is thrown if - `trainingSummary is None`. - """ - if self.hasSummary: - s = GeneralizedLinearRegressionTrainingSummary( - super(GeneralizedLinearRegressionModel, self).summary - ) - if is_remote(): - s.__source_transformer__ = self # type: ignore[attr-defined] - return s - else: - raise RuntimeError( - "No training summary available for this %s" % self.__class__.__name__ - ) + def _summaryCls(self) -> type: + return GeneralizedLinearRegressionTrainingSummary + + def _summary_dataset(self, train_dataset: DataFrame) -> DataFrame: + return train_dataset def evaluate(self, dataset: DataFrame) -> "GeneralizedLinearRegressionSummary": """ diff --git a/python/pyspark/ml/tests/connect/test_connect_cache.py b/python/pyspark/ml/tests/connect/test_connect_cache.py index 8d156a0f11e1d..f911ab22286c0 100644 --- a/python/pyspark/ml/tests/connect/test_connect_cache.py +++ b/python/pyspark/ml/tests/connect/test_connect_cache.py @@ -48,20 +48,24 @@ def test_delete_model(self): "obj: class org.apache.spark.ml.classification.LinearSVCModel" in cache_info[0], cache_info, ) - assert model._java_obj._ref_count == 1 + # the `model._summary` holds another ref to the remote model. 
+ assert model._java_obj._ref_count == 2 model2 = model.copy() cache_info = spark.client._get_ml_cache_info() self.assertEqual(len(cache_info), 1) - assert model._java_obj._ref_count == 2 - assert model2._java_obj._ref_count == 2 + assert model._java_obj._ref_count == 3 + assert model2._java_obj._ref_count == 3 # explicitly delete the model del model cache_info = spark.client._get_ml_cache_info() self.assertEqual(len(cache_info), 1) - assert model2._java_obj._ref_count == 1 + # Note the copied model 'model2' also holds the `_summary` object, + # and the `_summary` object holds another ref to the remote model. + # so the ref count is 2. + assert model2._java_obj._ref_count == 2 del model2 cache_info = spark.client._get_ml_cache_info() @@ -99,7 +103,6 @@ def test_cleanup_ml_cache(self): cache_info, ) - # explicitly delete the model1 del model1 cache_info = spark.client._get_ml_cache_info() diff --git a/python/pyspark/ml/tests/test_classification.py b/python/pyspark/ml/tests/test_classification.py index 57e4c0ef86dc0..21bce70e8735b 100644 --- a/python/pyspark/ml/tests/test_classification.py +++ b/python/pyspark/ml/tests/test_classification.py @@ -55,6 +55,7 @@ MultilayerPerceptronClassificationTrainingSummary, ) from pyspark.ml.regression import DecisionTreeRegressionModel +from pyspark.sql import is_remote from pyspark.testing.sqlutils import ReusedSQLTestCase @@ -241,37 +242,45 @@ def test_binary_logistic_regression_summary(self): model = lr.fit(df) self.assertEqual(lr.uid, model.uid) self.assertTrue(model.hasSummary) - s = model.summary - # test that api is callable and returns expected types - self.assertTrue(isinstance(s.predictions, DataFrame)) - self.assertEqual(s.probabilityCol, "probability") - self.assertEqual(s.labelCol, "label") - self.assertEqual(s.featuresCol, "features") - self.assertEqual(s.predictionCol, "prediction") - objHist = s.objectiveHistory - self.assertTrue(isinstance(objHist, list) and isinstance(objHist[0], float)) - self.assertGreater(s.totalIterations, 0) - self.assertTrue(isinstance(s.labels, list)) - self.assertTrue(isinstance(s.truePositiveRateByLabel, list)) - self.assertTrue(isinstance(s.falsePositiveRateByLabel, list)) - self.assertTrue(isinstance(s.precisionByLabel, list)) - self.assertTrue(isinstance(s.recallByLabel, list)) - self.assertTrue(isinstance(s.fMeasureByLabel(), list)) - self.assertTrue(isinstance(s.fMeasureByLabel(1.0), list)) - self.assertTrue(isinstance(s.roc, DataFrame)) - self.assertAlmostEqual(s.areaUnderROC, 1.0, 2) - self.assertTrue(isinstance(s.pr, DataFrame)) - self.assertTrue(isinstance(s.fMeasureByThreshold, DataFrame)) - self.assertTrue(isinstance(s.precisionByThreshold, DataFrame)) - self.assertTrue(isinstance(s.recallByThreshold, DataFrame)) - self.assertAlmostEqual(s.accuracy, 1.0, 2) - self.assertAlmostEqual(s.weightedTruePositiveRate, 1.0, 2) - self.assertAlmostEqual(s.weightedFalsePositiveRate, 0.0, 2) - self.assertAlmostEqual(s.weightedRecall, 1.0, 2) - self.assertAlmostEqual(s.weightedPrecision, 1.0, 2) - self.assertAlmostEqual(s.weightedFMeasure(), 1.0, 2) - self.assertAlmostEqual(s.weightedFMeasure(1.0), 1.0, 2) + def check_summary(): + s = model.summary + # test that api is callable and returns expected types + self.assertTrue(isinstance(s.predictions, DataFrame)) + self.assertEqual(s.probabilityCol, "probability") + self.assertEqual(s.labelCol, "label") + self.assertEqual(s.featuresCol, "features") + self.assertEqual(s.predictionCol, "prediction") + objHist = s.objectiveHistory + self.assertTrue(isinstance(objHist, 
list) and isinstance(objHist[0], float)) + self.assertGreater(s.totalIterations, 0) + self.assertTrue(isinstance(s.labels, list)) + self.assertTrue(isinstance(s.truePositiveRateByLabel, list)) + self.assertTrue(isinstance(s.falsePositiveRateByLabel, list)) + self.assertTrue(isinstance(s.precisionByLabel, list)) + self.assertTrue(isinstance(s.recallByLabel, list)) + self.assertTrue(isinstance(s.fMeasureByLabel(), list)) + self.assertTrue(isinstance(s.fMeasureByLabel(1.0), list)) + self.assertTrue(isinstance(s.roc, DataFrame)) + self.assertAlmostEqual(s.areaUnderROC, 1.0, 2) + self.assertTrue(isinstance(s.pr, DataFrame)) + self.assertTrue(isinstance(s.fMeasureByThreshold, DataFrame)) + self.assertTrue(isinstance(s.precisionByThreshold, DataFrame)) + self.assertTrue(isinstance(s.recallByThreshold, DataFrame)) + self.assertAlmostEqual(s.accuracy, 1.0, 2) + self.assertAlmostEqual(s.weightedTruePositiveRate, 1.0, 2) + self.assertAlmostEqual(s.weightedFalsePositiveRate, 0.0, 2) + self.assertAlmostEqual(s.weightedRecall, 1.0, 2) + self.assertAlmostEqual(s.weightedPrecision, 1.0, 2) + self.assertAlmostEqual(s.weightedFMeasure(), 1.0, 2) + self.assertAlmostEqual(s.weightedFMeasure(1.0), 1.0, 2) + + check_summary() + if is_remote(): + self.spark.client._delete_ml_cache([model._java_obj._ref_id], evict_only=True) + check_summary() + + s = model.summary # test evaluation (with training dataset) produces a summary with same values # one check is enough to verify a summary is returned, Scala version runs full test sameSummary = model.evaluate(df) @@ -292,31 +301,39 @@ def test_multiclass_logistic_regression_summary(self): lr = LogisticRegression(maxIter=5, regParam=0.01, weightCol="weight", fitIntercept=False) model = lr.fit(df) self.assertTrue(model.hasSummary) - s = model.summary - # test that api is callable and returns expected types - self.assertTrue(isinstance(s.predictions, DataFrame)) - self.assertEqual(s.probabilityCol, "probability") - self.assertEqual(s.labelCol, "label") - self.assertEqual(s.featuresCol, "features") - self.assertEqual(s.predictionCol, "prediction") - objHist = s.objectiveHistory - self.assertTrue(isinstance(objHist, list) and isinstance(objHist[0], float)) - self.assertGreater(s.totalIterations, 0) - self.assertTrue(isinstance(s.labels, list)) - self.assertTrue(isinstance(s.truePositiveRateByLabel, list)) - self.assertTrue(isinstance(s.falsePositiveRateByLabel, list)) - self.assertTrue(isinstance(s.precisionByLabel, list)) - self.assertTrue(isinstance(s.recallByLabel, list)) - self.assertTrue(isinstance(s.fMeasureByLabel(), list)) - self.assertTrue(isinstance(s.fMeasureByLabel(1.0), list)) - self.assertAlmostEqual(s.accuracy, 0.75, 2) - self.assertAlmostEqual(s.weightedTruePositiveRate, 0.75, 2) - self.assertAlmostEqual(s.weightedFalsePositiveRate, 0.25, 2) - self.assertAlmostEqual(s.weightedRecall, 0.75, 2) - self.assertAlmostEqual(s.weightedPrecision, 0.583, 2) - self.assertAlmostEqual(s.weightedFMeasure(), 0.65, 2) - self.assertAlmostEqual(s.weightedFMeasure(1.0), 0.65, 2) + def check_summary(): + s = model.summary + # test that api is callable and returns expected types + self.assertTrue(isinstance(s.predictions, DataFrame)) + self.assertEqual(s.probabilityCol, "probability") + self.assertEqual(s.labelCol, "label") + self.assertEqual(s.featuresCol, "features") + self.assertEqual(s.predictionCol, "prediction") + objHist = s.objectiveHistory + self.assertTrue(isinstance(objHist, list) and isinstance(objHist[0], float)) + self.assertGreater(s.totalIterations, 0) + 
self.assertTrue(isinstance(s.labels, list)) + self.assertTrue(isinstance(s.truePositiveRateByLabel, list)) + self.assertTrue(isinstance(s.falsePositiveRateByLabel, list)) + self.assertTrue(isinstance(s.precisionByLabel, list)) + self.assertTrue(isinstance(s.recallByLabel, list)) + self.assertTrue(isinstance(s.fMeasureByLabel(), list)) + self.assertTrue(isinstance(s.fMeasureByLabel(1.0), list)) + self.assertAlmostEqual(s.accuracy, 0.75, 2) + self.assertAlmostEqual(s.weightedTruePositiveRate, 0.75, 2) + self.assertAlmostEqual(s.weightedFalsePositiveRate, 0.25, 2) + self.assertAlmostEqual(s.weightedRecall, 0.75, 2) + self.assertAlmostEqual(s.weightedPrecision, 0.583, 2) + self.assertAlmostEqual(s.weightedFMeasure(), 0.65, 2) + self.assertAlmostEqual(s.weightedFMeasure(1.0), 0.65, 2) + + check_summary() + if is_remote(): + self.spark.client._delete_ml_cache([model._java_obj._ref_id], evict_only=True) + check_summary() + + s = model.summary # test evaluation (with training dataset) produces a summary with same values # one check is enough to verify a summary is returned, Scala version runs full test sameSummary = model.evaluate(df) @@ -426,15 +443,21 @@ def test_linear_svc(self): self.assertEqual(output.columns, expected_cols) self.assertEqual(output.count(), 4) - # model summary - self.assertTrue(model.hasSummary) - summary = model.summary() - self.assertIsInstance(summary, LinearSVCSummary) - self.assertIsInstance(summary, LinearSVCTrainingSummary) - self.assertEqual(summary.labels, [0.0, 1.0]) - self.assertEqual(summary.accuracy, 0.5) - self.assertEqual(summary.areaUnderROC, 0.75) - self.assertEqual(summary.predictions.columns, expected_cols) + def check_summary(): + # model summary + self.assertTrue(model.hasSummary) + summary = model.summary() + self.assertIsInstance(summary, LinearSVCSummary) + self.assertIsInstance(summary, LinearSVCTrainingSummary) + self.assertEqual(summary.labels, [0.0, 1.0]) + self.assertEqual(summary.accuracy, 0.5) + self.assertEqual(summary.areaUnderROC, 0.75) + self.assertEqual(summary.predictions.columns, expected_cols) + + check_summary() + if is_remote(): + self.spark.client._delete_ml_cache([model._java_obj._ref_id], evict_only=True) + check_summary() summary2 = model.evaluate(df) self.assertIsInstance(summary2, LinearSVCSummary) @@ -526,13 +549,20 @@ def test_factorization_machine(self): # model summary self.assertTrue(model.hasSummary) - summary = model.summary() - self.assertIsInstance(summary, FMClassificationSummary) - self.assertIsInstance(summary, FMClassificationTrainingSummary) - self.assertEqual(summary.labels, [0.0, 1.0]) - self.assertEqual(summary.accuracy, 0.25) - self.assertEqual(summary.areaUnderROC, 0.5) - self.assertEqual(summary.predictions.columns, expected_cols) + + def check_summary(): + summary = model.summary() + self.assertIsInstance(summary, FMClassificationSummary) + self.assertIsInstance(summary, FMClassificationTrainingSummary) + self.assertEqual(summary.labels, [0.0, 1.0]) + self.assertEqual(summary.accuracy, 0.25) + self.assertEqual(summary.areaUnderROC, 0.5) + self.assertEqual(summary.predictions.columns, expected_cols) + + check_summary() + if is_remote(): + self.spark.client._delete_ml_cache([model._java_obj._ref_id], evict_only=True) + check_summary() summary2 = model.evaluate(df) self.assertIsInstance(summary2, FMClassificationSummary) @@ -773,21 +803,27 @@ def test_binary_random_forest_classifier(self): self.assertEqual(tree.transform(df).count(), 4) self.assertEqual(tree.transform(df).columns, expected_cols) - # model 
summary - summary = model.summary - self.assertTrue(isinstance(summary, BinaryRandomForestClassificationSummary)) - self.assertTrue(isinstance(summary, BinaryRandomForestClassificationTrainingSummary)) - self.assertEqual(summary.labels, [0.0, 1.0]) - self.assertEqual(summary.accuracy, 0.75) - self.assertEqual(summary.areaUnderROC, 0.875) - self.assertEqual(summary.predictions.columns, expected_cols) + def check_summary(): + # model summary + summary = model.summary + self.assertTrue(isinstance(summary, BinaryRandomForestClassificationSummary)) + self.assertTrue(isinstance(summary, BinaryRandomForestClassificationTrainingSummary)) + self.assertEqual(summary.labels, [0.0, 1.0]) + self.assertEqual(summary.accuracy, 0.75) + self.assertEqual(summary.areaUnderROC, 0.875) + self.assertEqual(summary.predictions.columns, expected_cols) + + check_summary() + if is_remote(): + self.spark.client._delete_ml_cache([model._java_obj._ref_id], evict_only=True) + check_summary() summary2 = model.evaluate(df) self.assertTrue(isinstance(summary2, BinaryRandomForestClassificationSummary)) self.assertFalse(isinstance(summary2, BinaryRandomForestClassificationTrainingSummary)) self.assertEqual(summary2.labels, [0.0, 1.0]) self.assertEqual(summary2.accuracy, 0.75) - self.assertEqual(summary.areaUnderROC, 0.875) + self.assertEqual(summary2.areaUnderROC, 0.875) self.assertEqual(summary2.predictions.columns, expected_cols) # Model save & load @@ -859,13 +895,19 @@ def test_multiclass_random_forest_classifier(self): self.assertEqual(output.columns, expected_cols) self.assertEqual(output.count(), 4) - # model summary - summary = model.summary - self.assertTrue(isinstance(summary, RandomForestClassificationSummary)) - self.assertTrue(isinstance(summary, RandomForestClassificationTrainingSummary)) - self.assertEqual(summary.labels, [0.0, 1.0, 2.0]) - self.assertEqual(summary.accuracy, 0.5) - self.assertEqual(summary.predictions.columns, expected_cols) + def check_summary(): + # model summary + summary = model.summary + self.assertTrue(isinstance(summary, RandomForestClassificationSummary)) + self.assertTrue(isinstance(summary, RandomForestClassificationTrainingSummary)) + self.assertEqual(summary.labels, [0.0, 1.0, 2.0]) + self.assertEqual(summary.accuracy, 0.5) + self.assertEqual(summary.predictions.columns, expected_cols) + + check_summary() + if is_remote(): + self.spark.client._delete_ml_cache([model._java_obj._ref_id], evict_only=True) + check_summary() summary2 = model.evaluate(df) self.assertTrue(isinstance(summary2, RandomForestClassificationSummary)) @@ -953,14 +995,20 @@ def test_mlp(self): self.assertEqual(output.columns, expected_cols) self.assertEqual(output.count(), 4) - # model summary - self.assertTrue(model.hasSummary) - summary = model.summary() - self.assertIsInstance(summary, MultilayerPerceptronClassificationSummary) - self.assertIsInstance(summary, MultilayerPerceptronClassificationTrainingSummary) - self.assertEqual(summary.labels, [0.0, 1.0]) - self.assertEqual(summary.accuracy, 0.75) - self.assertEqual(summary.predictions.columns, expected_cols) + def check_summary(): + # model summary + self.assertTrue(model.hasSummary) + summary = model.summary() + self.assertIsInstance(summary, MultilayerPerceptronClassificationSummary) + self.assertIsInstance(summary, MultilayerPerceptronClassificationTrainingSummary) + self.assertEqual(summary.labels, [0.0, 1.0]) + self.assertEqual(summary.accuracy, 0.75) + self.assertEqual(summary.predictions.columns, expected_cols) + + check_summary() + if is_remote(): + 
self.spark.client._delete_ml_cache([model._java_obj._ref_id], evict_only=True) + check_summary() summary2 = model.evaluate(df) self.assertIsInstance(summary2, MultilayerPerceptronClassificationSummary) diff --git a/python/pyspark/ml/tests/test_clustering.py b/python/pyspark/ml/tests/test_clustering.py index 1b8eb73135a96..fbf012babcc3d 100644 --- a/python/pyspark/ml/tests/test_clustering.py +++ b/python/pyspark/ml/tests/test_clustering.py @@ -85,23 +85,39 @@ def test_kmeans(self): self.assertTrue(np.allclose(model.predict(Vectors.dense(0.0, 5.0)), 1, atol=1e-4)) - # Model summary - self.assertTrue(model.hasSummary) - summary = model.summary - self.assertTrue(isinstance(summary, KMeansSummary)) - self.assertEqual(summary.k, 2) - self.assertEqual(summary.numIter, 2) - self.assertEqual(summary.clusterSizes, [4, 2]) - self.assertTrue(np.allclose(summary.trainingCost, 1.35710375, atol=1e-4)) + def check_summary(): + # Model summary + self.assertTrue(model.hasSummary) + summary = model.summary + self.assertTrue(isinstance(summary, KMeansSummary)) + self.assertEqual(summary.k, 2) + self.assertEqual(summary.numIter, 2) + self.assertEqual(summary.clusterSizes, [4, 2]) + self.assertTrue(np.allclose(summary.trainingCost, 1.35710375, atol=1e-4)) - self.assertEqual(summary.featuresCol, "features") - self.assertEqual(summary.predictionCol, "prediction") + self.assertEqual(summary.featuresCol, "features") + self.assertEqual(summary.predictionCol, "prediction") - self.assertEqual(summary.cluster.columns, ["prediction"]) - self.assertEqual(summary.cluster.count(), 6) + self.assertEqual(summary.cluster.columns, ["prediction"]) + self.assertEqual(summary.cluster.count(), 6) - self.assertEqual(summary.predictions.columns, expected_cols) - self.assertEqual(summary.predictions.count(), 6) + self.assertEqual(summary.predictions.columns, expected_cols) + self.assertEqual(summary.predictions.count(), 6) + + # check summary before model offloading occurs + check_summary() + + if is_remote(): + self.spark.client._delete_ml_cache([model._java_obj._ref_id], evict_only=True) + # check summary "try_remote_call" path after model offloading occurs + self.assertEqual(model.summary.numIter, 2) + + self.spark.client._delete_ml_cache([model._java_obj._ref_id], evict_only=True) + # check summary "invoke_remote_attribute_relation" path after model offloading occurs + self.assertEqual(model.summary.cluster.count(), 6) + + self.spark.client._delete_ml_cache([model._java_obj._ref_id], evict_only=True) + check_summary() # save & load with tempfile.TemporaryDirectory(prefix="kmeans_model") as d: @@ -112,6 +128,9 @@ def test_kmeans(self): model.write().overwrite().save(d) model2 = KMeansModel.load(d) self.assertEqual(str(model), str(model2)) + self.assertFalse(model2.hasSummary) + with self.assertRaisesRegex(Exception, "No training summary available"): + model2.summary def test_bisecting_kmeans(self): df = ( @@ -278,30 +297,36 @@ def test_gaussian_mixture(self): self.assertEqual(output.columns, expected_cols) self.assertEqual(output.count(), 6) - # Model summary - self.assertTrue(model.hasSummary) - summary = model.summary - self.assertTrue(isinstance(summary, GaussianMixtureSummary)) - self.assertEqual(summary.k, 2) - self.assertEqual(summary.numIter, 2) - self.assertEqual(len(summary.clusterSizes), 2) - self.assertEqual(summary.clusterSizes, [3, 3]) - ll = summary.logLikelihood - self.assertTrue(ll < 0, ll) - self.assertTrue(np.allclose(ll, -1.311264553744033, atol=1e-4), ll) - - self.assertEqual(summary.featuresCol, "features") - 
self.assertEqual(summary.predictionCol, "prediction") - self.assertEqual(summary.probabilityCol, "probability") - - self.assertEqual(summary.cluster.columns, ["prediction"]) - self.assertEqual(summary.cluster.count(), 6) - - self.assertEqual(summary.predictions.columns, expected_cols) - self.assertEqual(summary.predictions.count(), 6) - - self.assertEqual(summary.probability.columns, ["probability"]) - self.assertEqual(summary.predictions.count(), 6) + def check_summary(): + # Model summary + self.assertTrue(model.hasSummary) + summary = model.summary + self.assertTrue(isinstance(summary, GaussianMixtureSummary)) + self.assertEqual(summary.k, 2) + self.assertEqual(summary.numIter, 2) + self.assertEqual(len(summary.clusterSizes), 2) + self.assertEqual(summary.clusterSizes, [3, 3]) + ll = summary.logLikelihood + self.assertTrue(ll < 0, ll) + self.assertTrue(np.allclose(ll, -1.311264553744033, atol=1e-4), ll) + + self.assertEqual(summary.featuresCol, "features") + self.assertEqual(summary.predictionCol, "prediction") + self.assertEqual(summary.probabilityCol, "probability") + + self.assertEqual(summary.cluster.columns, ["prediction"]) + self.assertEqual(summary.cluster.count(), 6) + + self.assertEqual(summary.predictions.columns, expected_cols) + self.assertEqual(summary.predictions.count(), 6) + + self.assertEqual(summary.probability.columns, ["probability"]) + self.assertEqual(summary.predictions.count(), 6) + + check_summary() + if is_remote(): + self.spark.client._delete_ml_cache([model._java_obj._ref_id], evict_only=True) + check_summary() # save & load with tempfile.TemporaryDirectory(prefix="gaussian_mixture") as d: diff --git a/python/pyspark/ml/tests/test_regression.py b/python/pyspark/ml/tests/test_regression.py index 8638fb4d6078e..52688fdd63cf2 100644 --- a/python/pyspark/ml/tests/test_regression.py +++ b/python/pyspark/ml/tests/test_regression.py @@ -43,6 +43,7 @@ GBTRegressor, GBTRegressionModel, ) +from pyspark.sql import is_remote from pyspark.testing.sqlutils import ReusedSQLTestCase @@ -193,50 +194,58 @@ def test_linear_regression(self): np.allclose(model.predict(Vectors.dense(0.0, 5.0)), 0.21249999999999963, atol=1e-4) ) - # Model summary - summary = model.summary - self.assertTrue(isinstance(summary, LinearRegressionSummary)) - self.assertTrue(isinstance(summary, LinearRegressionTrainingSummary)) - self.assertEqual(summary.predictions.columns, expected_cols) - self.assertEqual(summary.predictions.count(), 4) - self.assertEqual(summary.residuals.columns, ["residuals"]) - self.assertEqual(summary.residuals.count(), 4) - - self.assertEqual(summary.degreesOfFreedom, 1) - self.assertEqual(summary.numInstances, 4) - self.assertEqual(summary.objectiveHistory, [0.0]) - self.assertTrue( - np.allclose( - summary.coefficientStandardErrors, - [1.2859821149611763, 0.6248749874975031, 3.1645497310044184], - atol=1e-4, + def check_summary(): + # Model summary + summary = model.summary + self.assertTrue(isinstance(summary, LinearRegressionSummary)) + self.assertTrue(isinstance(summary, LinearRegressionTrainingSummary)) + self.assertEqual(summary.predictions.columns, expected_cols) + self.assertEqual(summary.predictions.count(), 4) + self.assertEqual(summary.residuals.columns, ["residuals"]) + self.assertEqual(summary.residuals.count(), 4) + + self.assertEqual(summary.degreesOfFreedom, 1) + self.assertEqual(summary.numInstances, 4) + self.assertEqual(summary.objectiveHistory, [0.0]) + self.assertTrue( + np.allclose( + summary.coefficientStandardErrors, + [1.2859821149611763, 
0.6248749874975031, 3.1645497310044184], + atol=1e-4, + ) ) - ) - self.assertTrue( - np.allclose( - summary.devianceResiduals, [-0.7424621202458727, 0.7875000000000003], atol=1e-4 + self.assertTrue( + np.allclose( + summary.devianceResiduals, [-0.7424621202458727, 0.7875000000000003], atol=1e-4 + ) ) - ) - self.assertTrue( - np.allclose( - summary.pValues, - [0.7020630236843428, 0.8866003086182783, 0.9298746994547682], - atol=1e-4, + self.assertTrue( + np.allclose( + summary.pValues, + [0.7020630236843428, 0.8866003086182783, 0.9298746994547682], + atol=1e-4, + ) ) - ) - self.assertTrue( - np.allclose( - summary.tValues, - [0.5054502643838291, 0.1800360108036021, -0.11060025272186746], - atol=1e-4, + self.assertTrue( + np.allclose( + summary.tValues, + [0.5054502643838291, 0.1800360108036021, -0.11060025272186746], + atol=1e-4, + ) ) - ) - self.assertTrue(np.allclose(summary.explainedVariance, 0.07997500000000031, atol=1e-4)) - self.assertTrue(np.allclose(summary.meanAbsoluteError, 0.4200000000000002, atol=1e-4)) - self.assertTrue(np.allclose(summary.meanSquaredError, 0.20212500000000005, atol=1e-4)) - self.assertTrue(np.allclose(summary.rootMeanSquaredError, 0.44958314025327956, atol=1e-4)) - self.assertTrue(np.allclose(summary.r2, 0.4427212572373862, atol=1e-4)) - self.assertTrue(np.allclose(summary.r2adj, -0.6718362282878414, atol=1e-4)) + self.assertTrue(np.allclose(summary.explainedVariance, 0.07997500000000031, atol=1e-4)) + self.assertTrue(np.allclose(summary.meanAbsoluteError, 0.4200000000000002, atol=1e-4)) + self.assertTrue(np.allclose(summary.meanSquaredError, 0.20212500000000005, atol=1e-4)) + self.assertTrue( + np.allclose(summary.rootMeanSquaredError, 0.44958314025327956, atol=1e-4) + ) + self.assertTrue(np.allclose(summary.r2, 0.4427212572373862, atol=1e-4)) + self.assertTrue(np.allclose(summary.r2adj, -0.6718362282878414, atol=1e-4)) + + check_summary() + if is_remote(): + self.spark.client._delete_ml_cache([model._java_obj._ref_id], evict_only=True) + check_summary() summary2 = model.evaluate(df) self.assertTrue(isinstance(summary2, LinearRegressionSummary)) @@ -318,36 +327,43 @@ def test_generalized_linear_regression(self): self.assertEqual(output.columns, expected_cols) self.assertEqual(output.count(), 4) - # Model summary - self.assertTrue(model.hasSummary) - - summary = model.summary - self.assertIsInstance(summary, GeneralizedLinearRegressionSummary) - self.assertIsInstance(summary, GeneralizedLinearRegressionTrainingSummary) - self.assertEqual(summary.numIterations, 1) - self.assertEqual(summary.numInstances, 4) - self.assertEqual(summary.rank, 3) - self.assertTrue( - np.allclose( + def check_summary(): + # Model summary + self.assertTrue(model.hasSummary) + + summary = model.summary + self.assertIsInstance(summary, GeneralizedLinearRegressionSummary) + self.assertIsInstance(summary, GeneralizedLinearRegressionTrainingSummary) + self.assertEqual(summary.numIterations, 1) + self.assertEqual(summary.numInstances, 4) + self.assertEqual(summary.rank, 3) + self.assertTrue( + np.allclose( + summary.tValues, + [0.3725037662281711, -0.49418209022924164, 2.6589353685797654], + atol=1e-4, + ), summary.tValues, - [0.3725037662281711, -0.49418209022924164, 2.6589353685797654], - atol=1e-4, - ), - summary.tValues, - ) - self.assertTrue( - np.allclose( + ) + self.assertTrue( + np.allclose( + summary.pValues, + [0.7729938686180984, 0.707802691825973, 0.22900885781807023], + atol=1e-4, + ), summary.pValues, - [0.7729938686180984, 0.707802691825973, 0.22900885781807023], - atol=1e-4, 
- ), - summary.pValues, - ) - self.assertEqual(summary.predictions.columns, expected_cols) - self.assertEqual(summary.predictions.count(), 4) - self.assertEqual(summary.residuals().columns, ["devianceResiduals"]) - self.assertEqual(summary.residuals().count(), 4) + ) + self.assertEqual(summary.predictions.columns, expected_cols) + self.assertEqual(summary.predictions.count(), 4) + self.assertEqual(summary.residuals().columns, ["devianceResiduals"]) + self.assertEqual(summary.residuals().count(), 4) + check_summary() + if is_remote(): + self.spark.client._delete_ml_cache([model._java_obj._ref_id], evict_only=True) + check_summary() + + summary = model.summary summary2 = model.evaluate(df) self.assertIsInstance(summary2, GeneralizedLinearRegressionSummary) self.assertNotIsInstance(summary2, GeneralizedLinearRegressionTrainingSummary) diff --git a/python/pyspark/ml/util.py b/python/pyspark/ml/util.py index b86178a97c382..3e55241b07e27 100644 --- a/python/pyspark/ml/util.py +++ b/python/pyspark/ml/util.py @@ -40,6 +40,7 @@ from contextlib import contextmanager from pyspark import since +from pyspark.errors.exceptions.connect import SparkException from pyspark.ml.common import inherit_doc from pyspark.sql import SparkSession from pyspark.sql.utils import is_remote @@ -72,20 +73,6 @@ _logger = logging.getLogger("pyspark.ml.util") -def try_remote_intermediate_result(f: FuncT) -> FuncT: - """Mark the function/property that returns the intermediate result of the remote call. - Eg, model.summary""" - - @functools.wraps(f) - def wrapped(self: "JavaWrapper") -> Any: - if is_remote() and "PYSPARK_NO_NAMESPACE_SHARE" not in os.environ: - return f"{str(self._java_obj)}.{f.__name__}" - else: - return f(self) - - return cast(FuncT, wrapped) - - def invoke_helper_attr(method: str, *args: Any) -> Any: from pyspark.ml.wrapper import JavaWrapper @@ -125,7 +112,12 @@ def invoke_remote_attribute_relation( object_id = instance._java_obj # type: ignore methods, obj_ref = _extract_id_methods(object_id) methods.append(pb2.Fetch.Method(method=method, args=serialize(session.client, *args))) - plan = AttributeRelation(obj_ref, methods) + + if methods[0].method == "summary": + child = instance._summary_dataset._plan # type: ignore + else: + child = None + plan = AttributeRelation(obj_ref, methods, child=child) # To delay the GC of the model, keep a reference to the source instance, # might be a model or a summary. 
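The `invoke_remote_attribute_relation` change above is the read half of the offload story: when the fetched attribute chain is rooted at `summary`, the training dataset's plan is attached as the child of the `AttributeRelation`, so the server can rebuild the summary even if the cached model has been evicted. Below is a condensed, hedged sketch of the client-side flow this enables; it assumes an active Spark Connect session `spark`, uses a made-up toy dataset, and calls `_delete_ml_cache(..., evict_only=True)` only as the private testing hook exercised by the tests above, not as a public API.

```python
from pyspark.ml.clustering import KMeans
from pyspark.ml.linalg import Vectors
from pyspark.sql.utils import is_remote

# Toy training set; `spark` is assumed to be an active Spark Connect session.
df = spark.createDataFrame(
    [(Vectors.dense(0.0, 0.0),), (Vectors.dense(0.1, 0.1),),
     (Vectors.dense(9.0, 9.0),), (Vectors.dense(9.1, 9.1),)],
    ["features"],
)

# On Connect the training summary is created eagerly at fit time and keeps a
# reference to the training dataset so it can be rebuilt after offloading.
model = KMeans(k=2, maxIter=2).fit(df)
assert model.hasSummary

if is_remote():
    # Evict the cached model from server memory; the offloaded copy stays on
    # the driver's local disk (private testing hook, mirroring the tests above).
    spark.client._delete_ml_cache([model._java_obj._ref_id], evict_only=True)

    # Scalar summary attributes go through try_remote_call: the server raises
    # CONNECT_ML.MODEL_SUMMARY_LOST, the client replays MlCommand.CreateSummary
    # with the retained training dataset, then retries the original call.
    assert model.summary.numIter <= 2

    # DataFrame-valued attributes go through invoke_remote_attribute_relation,
    # which ships the training dataset plan as the relation child so the server
    # can rebuild the summary while resolving the fetch.
    assert model.summary.cluster.count() == df.count()
```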
@@ -204,6 +196,15 @@ def wrapped(self: "JavaEstimator", dataset: "ConnectDataFrame") -> Any: _logger.warning(warning_msg) remote_model_ref = RemoteModelRef(model_info.obj_ref.id) model = self._create_model(remote_model_ref) + if isinstance(model, HasTrainingSummary): + summary_dataset = model._summary_dataset(dataset) + + summary = model._summaryCls(f"{str(model._java_obj)}.summary") # type: ignore + summary._summary_dataset = summary_dataset + summary._remote_model_obj = model._java_obj # type: ignore + summary._remote_model_obj.add_ref() + + model._summary = summary # type: ignore if model.__class__.__name__ not in ["Bucketizer"]: model._resetUid(self.uid) return self._copyValues(model) @@ -278,15 +279,16 @@ def try_remote_call(f: FuncT) -> FuncT: @functools.wraps(f) def wrapped(self: "JavaWrapper", name: str, *args: Any) -> Any: - if is_remote() and "PYSPARK_NO_NAMESPACE_SHARE" not in os.environ: - # Launch a remote call if possible - import pyspark.sql.connect.proto as pb2 - from pyspark.sql.connect.session import SparkSession + import pyspark.sql.connect.proto as pb2 + from pyspark.sql.connect.session import SparkSession + + session = SparkSession.getActiveSession() + + def remote_call() -> Any: from pyspark.ml.connect.util import _extract_id_methods from pyspark.ml.connect.serialize import serialize, deserialize from pyspark.ml.wrapper import JavaModel - session = SparkSession.getActiveSession() assert session is not None if self._java_obj == ML_CONNECT_HELPER_ID: obj_id = ML_CONNECT_HELPER_ID @@ -315,6 +317,28 @@ def wrapped(self: "JavaWrapper", name: str, *args: Any) -> Any: return model_info.obj_ref.id else: return deserialize(properties) + + if is_remote() and "PYSPARK_NO_NAMESPACE_SHARE" not in os.environ: + try: + return remote_call() + except SparkException as e: + if e.getErrorClass() == "CONNECT_ML.MODEL_SUMMARY_LOST": + # the model summary is lost because the remote model was offloaded, + # send request to restore model.summary + create_summary_command = pb2.Command() + create_summary_command.ml_command.create_summary.CopyFrom( + pb2.MlCommand.CreateSummary( + model_ref=pb2.ObjectRef( + id=self._remote_model_obj.ref_id # type: ignore + ), + dataset=self._summary_dataset._plan.plan( # type: ignore + session.client # type: ignore + ), + ) + ) + session.client.execute_command(create_summary_command) # type: ignore + + return remote_call() else: return f(self, name, *args) @@ -346,8 +370,11 @@ def wrapped(self: "JavaWrapper") -> Any: except Exception: return - if in_remote and isinstance(self._java_obj, RemoteModelRef): - self._java_obj.release_ref() + if in_remote: + if isinstance(self._java_obj, RemoteModelRef): + self._java_obj.release_ref() + if hasattr(self, "_remote_model_obj"): + self._remote_model_obj.release_ref() return else: return f(self) @@ -1076,17 +1103,32 @@ def hasSummary(self) -> bool: Indicates whether a training summary exists for this model instance. """ + if is_remote() and "PYSPARK_NO_NAMESPACE_SHARE" not in os.environ: + return hasattr(self, "_summary") return cast("JavaWrapper", self)._call_java("hasSummary") @property @since("2.1.0") - @try_remote_intermediate_result def summary(self) -> T: """ Gets summary of the model trained on the training set. An exception is thrown if no summary exists. 
""" - return cast("JavaWrapper", self)._call_java("summary") + if is_remote() and "PYSPARK_NO_NAMESPACE_SHARE" not in os.environ: + if hasattr(self, "_summary"): + return self._summary + else: + raise RuntimeError( + "No training summary available for this %s" % self.__class__.__name__ + ) + return self._summaryCls(cast("JavaWrapper", self)._call_java("summary")) + + @property + def _summaryCls(self) -> type: + raise NotImplementedError() + + def _summary_dataset(self, train_dataset: "DataFrame") -> "DataFrame": + return self.transform(train_dataset) # type: ignore class MetaAlgorithmReadWrite: diff --git a/python/pyspark/sql/connect/client/core.py b/python/pyspark/sql/connect/client/core.py index 34719f2b0ba6e..3cfb38fdfa7da 100644 --- a/python/pyspark/sql/connect/client/core.py +++ b/python/pyspark/sql/connect/client/core.py @@ -1985,7 +1985,7 @@ def _create_profile(self, profile: pb2.ResourceProfile) -> int: profile_id = properties["create_resource_profile_command_result"] return profile_id - def _delete_ml_cache(self, cache_ids: List[str]) -> List[str]: + def _delete_ml_cache(self, cache_ids: List[str], evict_only: bool = False) -> List[str]: # try best to delete the cache try: if len(cache_ids) > 0: @@ -1993,6 +1993,7 @@ def _delete_ml_cache(self, cache_ids: List[str]) -> List[str]: command.ml_command.delete.obj_refs.extend( [pb2.ObjectRef(id=cache_id) for cache_id in cache_ids] ) + command.ml_command.delete.evict_only = evict_only (_, properties, _) = self.execute_command(command) assert properties is not None diff --git a/python/pyspark/sql/connect/proto/ml_pb2.py b/python/pyspark/sql/connect/proto/ml_pb2.py index 46fc82131a9e7..1ede558b94140 100644 --- a/python/pyspark/sql/connect/proto/ml_pb2.py +++ b/python/pyspark/sql/connect/proto/ml_pb2.py @@ -40,7 +40,7 @@ DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile( - b'\n\x16spark/connect/ml.proto\x12\rspark.connect\x1a\x1dspark/connect/relations.proto\x1a\x1fspark/connect/expressions.proto\x1a\x1dspark/connect/ml_common.proto"\xb2\x0b\n\tMlCommand\x12\x30\n\x03\x66it\x18\x01 \x01(\x0b\x32\x1c.spark.connect.MlCommand.FitH\x00R\x03\x66it\x12,\n\x05\x66\x65tch\x18\x02 \x01(\x0b\x32\x14.spark.connect.FetchH\x00R\x05\x66\x65tch\x12\x39\n\x06\x64\x65lete\x18\x03 \x01(\x0b\x32\x1f.spark.connect.MlCommand.DeleteH\x00R\x06\x64\x65lete\x12\x36\n\x05write\x18\x04 \x01(\x0b\x32\x1e.spark.connect.MlCommand.WriteH\x00R\x05write\x12\x33\n\x04read\x18\x05 \x01(\x0b\x32\x1d.spark.connect.MlCommand.ReadH\x00R\x04read\x12?\n\x08\x65valuate\x18\x06 \x01(\x0b\x32!.spark.connect.MlCommand.EvaluateH\x00R\x08\x65valuate\x12\x46\n\x0b\x63lean_cache\x18\x07 \x01(\x0b\x32#.spark.connect.MlCommand.CleanCacheH\x00R\ncleanCache\x12M\n\x0eget_cache_info\x18\x08 \x01(\x0b\x32%.spark.connect.MlCommand.GetCacheInfoH\x00R\x0cgetCacheInfo\x1a\xb2\x01\n\x03\x46it\x12\x37\n\testimator\x18\x01 \x01(\x0b\x32\x19.spark.connect.MlOperatorR\testimator\x12\x34\n\x06params\x18\x02 \x01(\x0b\x32\x17.spark.connect.MlParamsH\x00R\x06params\x88\x01\x01\x12\x31\n\x07\x64\x61taset\x18\x03 \x01(\x0b\x32\x17.spark.connect.RelationR\x07\x64\x61tasetB\t\n\x07_params\x1a=\n\x06\x44\x65lete\x12\x33\n\x08obj_refs\x18\x01 \x03(\x0b\x32\x18.spark.connect.ObjectRefR\x07objRefs\x1a\x0c\n\nCleanCache\x1a\x0e\n\x0cGetCacheInfo\x1a\x9a\x03\n\x05Write\x12\x37\n\x08operator\x18\x01 \x01(\x0b\x32\x19.spark.connect.MlOperatorH\x00R\x08operator\x12\x33\n\x07obj_ref\x18\x02 \x01(\x0b\x32\x18.spark.connect.ObjectRefH\x00R\x06objRef\x12\x34\n\x06params\x18\x03 
\x01(\x0b\x32\x17.spark.connect.MlParamsH\x01R\x06params\x88\x01\x01\x12\x12\n\x04path\x18\x04 \x01(\tR\x04path\x12.\n\x10should_overwrite\x18\x05 \x01(\x08H\x02R\x0fshouldOverwrite\x88\x01\x01\x12\x45\n\x07options\x18\x06 \x03(\x0b\x32+.spark.connect.MlCommand.Write.OptionsEntryR\x07options\x1a:\n\x0cOptionsEntry\x12\x10\n\x03key\x18\x01 \x01(\tR\x03key\x12\x14\n\x05value\x18\x02 \x01(\tR\x05value:\x02\x38\x01\x42\x06\n\x04typeB\t\n\x07_paramsB\x13\n\x11_should_overwrite\x1aQ\n\x04Read\x12\x35\n\x08operator\x18\x01 \x01(\x0b\x32\x19.spark.connect.MlOperatorR\x08operator\x12\x12\n\x04path\x18\x02 \x01(\tR\x04path\x1a\xb7\x01\n\x08\x45valuate\x12\x37\n\tevaluator\x18\x01 \x01(\x0b\x32\x19.spark.connect.MlOperatorR\tevaluator\x12\x34\n\x06params\x18\x02 \x01(\x0b\x32\x17.spark.connect.MlParamsH\x00R\x06params\x88\x01\x01\x12\x31\n\x07\x64\x61taset\x18\x03 \x01(\x0b\x32\x17.spark.connect.RelationR\x07\x64\x61tasetB\t\n\x07_paramsB\t\n\x07\x63ommand"\xd5\x03\n\x0fMlCommandResult\x12\x39\n\x05param\x18\x01 \x01(\x0b\x32!.spark.connect.Expression.LiteralH\x00R\x05param\x12\x1a\n\x07summary\x18\x02 \x01(\tH\x00R\x07summary\x12T\n\roperator_info\x18\x03 \x01(\x0b\x32-.spark.connect.MlCommandResult.MlOperatorInfoH\x00R\x0coperatorInfo\x1a\x85\x02\n\x0eMlOperatorInfo\x12\x33\n\x07obj_ref\x18\x01 \x01(\x0b\x32\x18.spark.connect.ObjectRefH\x00R\x06objRef\x12\x14\n\x04name\x18\x02 \x01(\tH\x00R\x04name\x12\x15\n\x03uid\x18\x03 \x01(\tH\x01R\x03uid\x88\x01\x01\x12\x34\n\x06params\x18\x04 \x01(\x0b\x32\x17.spark.connect.MlParamsH\x02R\x06params\x88\x01\x01\x12,\n\x0fwarning_message\x18\x05 \x01(\tH\x03R\x0ewarningMessage\x88\x01\x01\x42\x06\n\x04typeB\x06\n\x04_uidB\t\n\x07_paramsB\x12\n\x10_warning_messageB\r\n\x0bresult_typeB6\n\x1eorg.apache.spark.connect.protoP\x01Z\x12internal/generatedb\x06proto3' + b'\n\x16spark/connect/ml.proto\x12\rspark.connect\x1a\x1dspark/connect/relations.proto\x1a\x1fspark/connect/expressions.proto\x1a\x1dspark/connect/ml_common.proto"\xb1\r\n\tMlCommand\x12\x30\n\x03\x66it\x18\x01 \x01(\x0b\x32\x1c.spark.connect.MlCommand.FitH\x00R\x03\x66it\x12,\n\x05\x66\x65tch\x18\x02 \x01(\x0b\x32\x14.spark.connect.FetchH\x00R\x05\x66\x65tch\x12\x39\n\x06\x64\x65lete\x18\x03 \x01(\x0b\x32\x1f.spark.connect.MlCommand.DeleteH\x00R\x06\x64\x65lete\x12\x36\n\x05write\x18\x04 \x01(\x0b\x32\x1e.spark.connect.MlCommand.WriteH\x00R\x05write\x12\x33\n\x04read\x18\x05 \x01(\x0b\x32\x1d.spark.connect.MlCommand.ReadH\x00R\x04read\x12?\n\x08\x65valuate\x18\x06 \x01(\x0b\x32!.spark.connect.MlCommand.EvaluateH\x00R\x08\x65valuate\x12\x46\n\x0b\x63lean_cache\x18\x07 \x01(\x0b\x32#.spark.connect.MlCommand.CleanCacheH\x00R\ncleanCache\x12M\n\x0eget_cache_info\x18\x08 \x01(\x0b\x32%.spark.connect.MlCommand.GetCacheInfoH\x00R\x0cgetCacheInfo\x12O\n\x0e\x63reate_summary\x18\t \x01(\x0b\x32&.spark.connect.MlCommand.CreateSummaryH\x00R\rcreateSummary\x1a\xb2\x01\n\x03\x46it\x12\x37\n\testimator\x18\x01 \x01(\x0b\x32\x19.spark.connect.MlOperatorR\testimator\x12\x34\n\x06params\x18\x02 \x01(\x0b\x32\x17.spark.connect.MlParamsH\x00R\x06params\x88\x01\x01\x12\x31\n\x07\x64\x61taset\x18\x03 \x01(\x0b\x32\x17.spark.connect.RelationR\x07\x64\x61tasetB\t\n\x07_params\x1ap\n\x06\x44\x65lete\x12\x33\n\x08obj_refs\x18\x01 \x03(\x0b\x32\x18.spark.connect.ObjectRefR\x07objRefs\x12"\n\nevict_only\x18\x02 \x01(\x08H\x00R\tevictOnly\x88\x01\x01\x42\r\n\x0b_evict_only\x1a\x0c\n\nCleanCache\x1a\x0e\n\x0cGetCacheInfo\x1a\x9a\x03\n\x05Write\x12\x37\n\x08operator\x18\x01 
\x01(\x0b\x32\x19.spark.connect.MlOperatorH\x00R\x08operator\x12\x33\n\x07obj_ref\x18\x02 \x01(\x0b\x32\x18.spark.connect.ObjectRefH\x00R\x06objRef\x12\x34\n\x06params\x18\x03 \x01(\x0b\x32\x17.spark.connect.MlParamsH\x01R\x06params\x88\x01\x01\x12\x12\n\x04path\x18\x04 \x01(\tR\x04path\x12.\n\x10should_overwrite\x18\x05 \x01(\x08H\x02R\x0fshouldOverwrite\x88\x01\x01\x12\x45\n\x07options\x18\x06 \x03(\x0b\x32+.spark.connect.MlCommand.Write.OptionsEntryR\x07options\x1a:\n\x0cOptionsEntry\x12\x10\n\x03key\x18\x01 \x01(\tR\x03key\x12\x14\n\x05value\x18\x02 \x01(\tR\x05value:\x02\x38\x01\x42\x06\n\x04typeB\t\n\x07_paramsB\x13\n\x11_should_overwrite\x1aQ\n\x04Read\x12\x35\n\x08operator\x18\x01 \x01(\x0b\x32\x19.spark.connect.MlOperatorR\x08operator\x12\x12\n\x04path\x18\x02 \x01(\tR\x04path\x1a\xb7\x01\n\x08\x45valuate\x12\x37\n\tevaluator\x18\x01 \x01(\x0b\x32\x19.spark.connect.MlOperatorR\tevaluator\x12\x34\n\x06params\x18\x02 \x01(\x0b\x32\x17.spark.connect.MlParamsH\x00R\x06params\x88\x01\x01\x12\x31\n\x07\x64\x61taset\x18\x03 \x01(\x0b\x32\x17.spark.connect.RelationR\x07\x64\x61tasetB\t\n\x07_params\x1ay\n\rCreateSummary\x12\x35\n\tmodel_ref\x18\x01 \x01(\x0b\x32\x18.spark.connect.ObjectRefR\x08modelRef\x12\x31\n\x07\x64\x61taset\x18\x02 \x01(\x0b\x32\x17.spark.connect.RelationR\x07\x64\x61tasetB\t\n\x07\x63ommand"\xd5\x03\n\x0fMlCommandResult\x12\x39\n\x05param\x18\x01 \x01(\x0b\x32!.spark.connect.Expression.LiteralH\x00R\x05param\x12\x1a\n\x07summary\x18\x02 \x01(\tH\x00R\x07summary\x12T\n\roperator_info\x18\x03 \x01(\x0b\x32-.spark.connect.MlCommandResult.MlOperatorInfoH\x00R\x0coperatorInfo\x1a\x85\x02\n\x0eMlOperatorInfo\x12\x33\n\x07obj_ref\x18\x01 \x01(\x0b\x32\x18.spark.connect.ObjectRefH\x00R\x06objRef\x12\x14\n\x04name\x18\x02 \x01(\tH\x00R\x04name\x12\x15\n\x03uid\x18\x03 \x01(\tH\x01R\x03uid\x88\x01\x01\x12\x34\n\x06params\x18\x04 \x01(\x0b\x32\x17.spark.connect.MlParamsH\x02R\x06params\x88\x01\x01\x12,\n\x0fwarning_message\x18\x05 \x01(\tH\x03R\x0ewarningMessage\x88\x01\x01\x42\x06\n\x04typeB\x06\n\x04_uidB\t\n\x07_paramsB\x12\n\x10_warning_messageB\r\n\x0bresult_typeB6\n\x1eorg.apache.spark.connect.protoP\x01Z\x12internal/generatedb\x06proto3' ) _globals = globals() @@ -54,25 +54,27 @@ _globals["_MLCOMMAND_WRITE_OPTIONSENTRY"]._loaded_options = None _globals["_MLCOMMAND_WRITE_OPTIONSENTRY"]._serialized_options = b"8\001" _globals["_MLCOMMAND"]._serialized_start = 137 - _globals["_MLCOMMAND"]._serialized_end = 1595 - _globals["_MLCOMMAND_FIT"]._serialized_start = 631 - _globals["_MLCOMMAND_FIT"]._serialized_end = 809 - _globals["_MLCOMMAND_DELETE"]._serialized_start = 811 - _globals["_MLCOMMAND_DELETE"]._serialized_end = 872 - _globals["_MLCOMMAND_CLEANCACHE"]._serialized_start = 874 - _globals["_MLCOMMAND_CLEANCACHE"]._serialized_end = 886 - _globals["_MLCOMMAND_GETCACHEINFO"]._serialized_start = 888 - _globals["_MLCOMMAND_GETCACHEINFO"]._serialized_end = 902 - _globals["_MLCOMMAND_WRITE"]._serialized_start = 905 - _globals["_MLCOMMAND_WRITE"]._serialized_end = 1315 - _globals["_MLCOMMAND_WRITE_OPTIONSENTRY"]._serialized_start = 1217 - _globals["_MLCOMMAND_WRITE_OPTIONSENTRY"]._serialized_end = 1275 - _globals["_MLCOMMAND_READ"]._serialized_start = 1317 - _globals["_MLCOMMAND_READ"]._serialized_end = 1398 - _globals["_MLCOMMAND_EVALUATE"]._serialized_start = 1401 - _globals["_MLCOMMAND_EVALUATE"]._serialized_end = 1584 - _globals["_MLCOMMANDRESULT"]._serialized_start = 1598 - _globals["_MLCOMMANDRESULT"]._serialized_end = 2067 - 
_globals["_MLCOMMANDRESULT_MLOPERATORINFO"]._serialized_start = 1791 - _globals["_MLCOMMANDRESULT_MLOPERATORINFO"]._serialized_end = 2052 + _globals["_MLCOMMAND"]._serialized_end = 1850 + _globals["_MLCOMMAND_FIT"]._serialized_start = 712 + _globals["_MLCOMMAND_FIT"]._serialized_end = 890 + _globals["_MLCOMMAND_DELETE"]._serialized_start = 892 + _globals["_MLCOMMAND_DELETE"]._serialized_end = 1004 + _globals["_MLCOMMAND_CLEANCACHE"]._serialized_start = 1006 + _globals["_MLCOMMAND_CLEANCACHE"]._serialized_end = 1018 + _globals["_MLCOMMAND_GETCACHEINFO"]._serialized_start = 1020 + _globals["_MLCOMMAND_GETCACHEINFO"]._serialized_end = 1034 + _globals["_MLCOMMAND_WRITE"]._serialized_start = 1037 + _globals["_MLCOMMAND_WRITE"]._serialized_end = 1447 + _globals["_MLCOMMAND_WRITE_OPTIONSENTRY"]._serialized_start = 1349 + _globals["_MLCOMMAND_WRITE_OPTIONSENTRY"]._serialized_end = 1407 + _globals["_MLCOMMAND_READ"]._serialized_start = 1449 + _globals["_MLCOMMAND_READ"]._serialized_end = 1530 + _globals["_MLCOMMAND_EVALUATE"]._serialized_start = 1533 + _globals["_MLCOMMAND_EVALUATE"]._serialized_end = 1716 + _globals["_MLCOMMAND_CREATESUMMARY"]._serialized_start = 1718 + _globals["_MLCOMMAND_CREATESUMMARY"]._serialized_end = 1839 + _globals["_MLCOMMANDRESULT"]._serialized_start = 1853 + _globals["_MLCOMMANDRESULT"]._serialized_end = 2322 + _globals["_MLCOMMANDRESULT_MLOPERATORINFO"]._serialized_start = 2046 + _globals["_MLCOMMANDRESULT_MLOPERATORINFO"]._serialized_end = 2307 # @@protoc_insertion_point(module_scope) diff --git a/python/pyspark/sql/connect/proto/ml_pb2.pyi b/python/pyspark/sql/connect/proto/ml_pb2.pyi index 88cc6cb625ded..0a72c207b5264 100644 --- a/python/pyspark/sql/connect/proto/ml_pb2.pyi +++ b/python/pyspark/sql/connect/proto/ml_pb2.pyi @@ -118,21 +118,39 @@ class MlCommand(google.protobuf.message.Message): DESCRIPTOR: google.protobuf.descriptor.Descriptor OBJ_REFS_FIELD_NUMBER: builtins.int + EVICT_ONLY_FIELD_NUMBER: builtins.int @property def obj_refs( self, ) -> google.protobuf.internal.containers.RepeatedCompositeFieldContainer[ pyspark.sql.connect.proto.ml_common_pb2.ObjectRef ]: ... + evict_only: builtins.bool + """if set `evict_only` to true, only evict the cached model from memory, + but keep the offloaded model in Spark driver local disk. + """ def __init__( self, *, obj_refs: collections.abc.Iterable[pyspark.sql.connect.proto.ml_common_pb2.ObjectRef] | None = ..., + evict_only: builtins.bool | None = ..., ) -> None: ... + def HasField( + self, + field_name: typing_extensions.Literal[ + "_evict_only", b"_evict_only", "evict_only", b"evict_only" + ], + ) -> builtins.bool: ... def ClearField( - self, field_name: typing_extensions.Literal["obj_refs", b"obj_refs"] + self, + field_name: typing_extensions.Literal[ + "_evict_only", b"_evict_only", "evict_only", b"evict_only", "obj_refs", b"obj_refs" + ], ) -> None: ... + def WhichOneof( + self, oneof_group: typing_extensions.Literal["_evict_only", b"_evict_only"] + ) -> typing_extensions.Literal["evict_only"] | None: ... class CleanCache(google.protobuf.message.Message): """Force to clean up all the ML cached objects""" @@ -342,6 +360,34 @@ class MlCommand(google.protobuf.message.Message): self, oneof_group: typing_extensions.Literal["_params", b"_params"] ) -> typing_extensions.Literal["params"] | None: ... 
+ class CreateSummary(google.protobuf.message.Message): + """This is for re-creating the model summary when the model summary is lost + (model summary is lost when the model is offloaded and then loaded back) + """ + + DESCRIPTOR: google.protobuf.descriptor.Descriptor + + MODEL_REF_FIELD_NUMBER: builtins.int + DATASET_FIELD_NUMBER: builtins.int + @property + def model_ref(self) -> pyspark.sql.connect.proto.ml_common_pb2.ObjectRef: ... + @property + def dataset(self) -> pyspark.sql.connect.proto.relations_pb2.Relation: ... + def __init__( + self, + *, + model_ref: pyspark.sql.connect.proto.ml_common_pb2.ObjectRef | None = ..., + dataset: pyspark.sql.connect.proto.relations_pb2.Relation | None = ..., + ) -> None: ... + def HasField( + self, + field_name: typing_extensions.Literal["dataset", b"dataset", "model_ref", b"model_ref"], + ) -> builtins.bool: ... + def ClearField( + self, + field_name: typing_extensions.Literal["dataset", b"dataset", "model_ref", b"model_ref"], + ) -> None: ... + FIT_FIELD_NUMBER: builtins.int FETCH_FIELD_NUMBER: builtins.int DELETE_FIELD_NUMBER: builtins.int @@ -350,6 +396,7 @@ class MlCommand(google.protobuf.message.Message): EVALUATE_FIELD_NUMBER: builtins.int CLEAN_CACHE_FIELD_NUMBER: builtins.int GET_CACHE_INFO_FIELD_NUMBER: builtins.int + CREATE_SUMMARY_FIELD_NUMBER: builtins.int @property def fit(self) -> global___MlCommand.Fit: ... @property @@ -366,6 +413,8 @@ class MlCommand(google.protobuf.message.Message): def clean_cache(self) -> global___MlCommand.CleanCache: ... @property def get_cache_info(self) -> global___MlCommand.GetCacheInfo: ... + @property + def create_summary(self) -> global___MlCommand.CreateSummary: ... def __init__( self, *, @@ -377,6 +426,7 @@ class MlCommand(google.protobuf.message.Message): evaluate: global___MlCommand.Evaluate | None = ..., clean_cache: global___MlCommand.CleanCache | None = ..., get_cache_info: global___MlCommand.GetCacheInfo | None = ..., + create_summary: global___MlCommand.CreateSummary | None = ..., ) -> None: ... def HasField( self, @@ -385,6 +435,8 @@ class MlCommand(google.protobuf.message.Message): b"clean_cache", "command", b"command", + "create_summary", + b"create_summary", "delete", b"delete", "evaluate", @@ -408,6 +460,8 @@ class MlCommand(google.protobuf.message.Message): b"clean_cache", "command", b"command", + "create_summary", + b"create_summary", "delete", b"delete", "evaluate", @@ -428,7 +482,15 @@ class MlCommand(google.protobuf.message.Message): self, oneof_group: typing_extensions.Literal["command", b"command"] ) -> ( typing_extensions.Literal[ - "fit", "fetch", "delete", "write", "read", "evaluate", "clean_cache", "get_cache_info" + "fit", + "fetch", + "delete", + "write", + "read", + "evaluate", + "clean_cache", + "get_cache_info", + "create_summary", ] | None ): ... 
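The stub above pins down the wire shape of the two protocol additions: `Delete` gains an optional `evict_only` flag, and `CreateSummary` carries a model reference plus the training dataset relation. The sketch below shows how a client could assemble these commands, mirroring `_delete_ml_cache` and the retry path in `try_remote_call` above; the helper function names are illustrative only, not part of the API.

```python
import pyspark.sql.connect.proto as pb2


def evict_model(session, ref_id: str) -> None:
    """Evict the cached model from server memory, keeping the offloaded copy on disk."""
    command = pb2.Command()
    command.ml_command.delete.obj_refs.extend([pb2.ObjectRef(id=ref_id)])
    command.ml_command.delete.evict_only = True
    session.client.execute_command(command)


def recreate_summary(session, ref_id: str, training_df) -> None:
    """Ask the server to rebuild a lost model summary from the training dataset plan."""
    command = pb2.Command()
    command.ml_command.create_summary.CopyFrom(
        pb2.MlCommand.CreateSummary(
            model_ref=pb2.ObjectRef(id=ref_id),
            dataset=training_df._plan.plan(session.client),
        )
    )
    session.client.execute_command(command)
```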
diff --git a/python/pyspark/sql/connect/proto/relations_pb2.py b/python/pyspark/sql/connect/proto/relations_pb2.py index 525ba88ff67c6..3774bcbdbfb0e 100644 --- a/python/pyspark/sql/connect/proto/relations_pb2.py +++ b/python/pyspark/sql/connect/proto/relations_pb2.py @@ -43,7 +43,7 @@ DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile( - b'\n\x1dspark/connect/relations.proto\x12\rspark.connect\x1a\x19google/protobuf/any.proto\x1a\x1fspark/connect/expressions.proto\x1a\x19spark/connect/types.proto\x1a\x1bspark/connect/catalog.proto\x1a\x1aspark/connect/common.proto\x1a\x1dspark/connect/ml_common.proto"\x9c\x1d\n\x08Relation\x12\x35\n\x06\x63ommon\x18\x01 \x01(\x0b\x32\x1d.spark.connect.RelationCommonR\x06\x63ommon\x12)\n\x04read\x18\x02 \x01(\x0b\x32\x13.spark.connect.ReadH\x00R\x04read\x12\x32\n\x07project\x18\x03 \x01(\x0b\x32\x16.spark.connect.ProjectH\x00R\x07project\x12/\n\x06\x66ilter\x18\x04 \x01(\x0b\x32\x15.spark.connect.FilterH\x00R\x06\x66ilter\x12)\n\x04join\x18\x05 \x01(\x0b\x32\x13.spark.connect.JoinH\x00R\x04join\x12\x34\n\x06set_op\x18\x06 \x01(\x0b\x32\x1b.spark.connect.SetOperationH\x00R\x05setOp\x12)\n\x04sort\x18\x07 \x01(\x0b\x32\x13.spark.connect.SortH\x00R\x04sort\x12,\n\x05limit\x18\x08 \x01(\x0b\x32\x14.spark.connect.LimitH\x00R\x05limit\x12\x38\n\taggregate\x18\t \x01(\x0b\x32\x18.spark.connect.AggregateH\x00R\taggregate\x12&\n\x03sql\x18\n \x01(\x0b\x32\x12.spark.connect.SQLH\x00R\x03sql\x12\x45\n\x0elocal_relation\x18\x0b \x01(\x0b\x32\x1c.spark.connect.LocalRelationH\x00R\rlocalRelation\x12/\n\x06sample\x18\x0c \x01(\x0b\x32\x15.spark.connect.SampleH\x00R\x06sample\x12/\n\x06offset\x18\r \x01(\x0b\x32\x15.spark.connect.OffsetH\x00R\x06offset\x12>\n\x0b\x64\x65\x64uplicate\x18\x0e \x01(\x0b\x32\x1a.spark.connect.DeduplicateH\x00R\x0b\x64\x65\x64uplicate\x12,\n\x05range\x18\x0f \x01(\x0b\x32\x14.spark.connect.RangeH\x00R\x05range\x12\x45\n\x0esubquery_alias\x18\x10 \x01(\x0b\x32\x1c.spark.connect.SubqueryAliasH\x00R\rsubqueryAlias\x12>\n\x0brepartition\x18\x11 \x01(\x0b\x32\x1a.spark.connect.RepartitionH\x00R\x0brepartition\x12*\n\x05to_df\x18\x12 \x01(\x0b\x32\x13.spark.connect.ToDFH\x00R\x04toDf\x12U\n\x14with_columns_renamed\x18\x13 \x01(\x0b\x32!.spark.connect.WithColumnsRenamedH\x00R\x12withColumnsRenamed\x12<\n\x0bshow_string\x18\x14 \x01(\x0b\x32\x19.spark.connect.ShowStringH\x00R\nshowString\x12)\n\x04\x64rop\x18\x15 \x01(\x0b\x32\x13.spark.connect.DropH\x00R\x04\x64rop\x12)\n\x04tail\x18\x16 \x01(\x0b\x32\x13.spark.connect.TailH\x00R\x04tail\x12?\n\x0cwith_columns\x18\x17 \x01(\x0b\x32\x1a.spark.connect.WithColumnsH\x00R\x0bwithColumns\x12)\n\x04hint\x18\x18 \x01(\x0b\x32\x13.spark.connect.HintH\x00R\x04hint\x12\x32\n\x07unpivot\x18\x19 \x01(\x0b\x32\x16.spark.connect.UnpivotH\x00R\x07unpivot\x12\x36\n\tto_schema\x18\x1a \x01(\x0b\x32\x17.spark.connect.ToSchemaH\x00R\x08toSchema\x12\x64\n\x19repartition_by_expression\x18\x1b \x01(\x0b\x32&.spark.connect.RepartitionByExpressionH\x00R\x17repartitionByExpression\x12\x45\n\x0emap_partitions\x18\x1c \x01(\x0b\x32\x1c.spark.connect.MapPartitionsH\x00R\rmapPartitions\x12H\n\x0f\x63ollect_metrics\x18\x1d \x01(\x0b\x32\x1d.spark.connect.CollectMetricsH\x00R\x0e\x63ollectMetrics\x12,\n\x05parse\x18\x1e \x01(\x0b\x32\x14.spark.connect.ParseH\x00R\x05parse\x12\x36\n\tgroup_map\x18\x1f \x01(\x0b\x32\x17.spark.connect.GroupMapH\x00R\x08groupMap\x12=\n\x0c\x63o_group_map\x18 \x01(\x0b\x32\x19.spark.connect.CoGroupMapH\x00R\ncoGroupMap\x12\x45\n\x0ewith_watermark\x18! 
\x01(\x0b\x32\x1c.spark.connect.WithWatermarkH\x00R\rwithWatermark\x12\x63\n\x1a\x61pply_in_pandas_with_state\x18" \x01(\x0b\x32%.spark.connect.ApplyInPandasWithStateH\x00R\x16\x61pplyInPandasWithState\x12<\n\x0bhtml_string\x18# \x01(\x0b\x32\x19.spark.connect.HtmlStringH\x00R\nhtmlString\x12X\n\x15\x63\x61\x63hed_local_relation\x18$ \x01(\x0b\x32".spark.connect.CachedLocalRelationH\x00R\x13\x63\x61\x63hedLocalRelation\x12[\n\x16\x63\x61\x63hed_remote_relation\x18% \x01(\x0b\x32#.spark.connect.CachedRemoteRelationH\x00R\x14\x63\x61\x63hedRemoteRelation\x12\x8e\x01\n)common_inline_user_defined_table_function\x18& \x01(\x0b\x32\x33.spark.connect.CommonInlineUserDefinedTableFunctionH\x00R$commonInlineUserDefinedTableFunction\x12\x37\n\nas_of_join\x18\' \x01(\x0b\x32\x17.spark.connect.AsOfJoinH\x00R\x08\x61sOfJoin\x12\x85\x01\n&common_inline_user_defined_data_source\x18( \x01(\x0b\x32\x30.spark.connect.CommonInlineUserDefinedDataSourceH\x00R!commonInlineUserDefinedDataSource\x12\x45\n\x0ewith_relations\x18) \x01(\x0b\x32\x1c.spark.connect.WithRelationsH\x00R\rwithRelations\x12\x38\n\ttranspose\x18* \x01(\x0b\x32\x18.spark.connect.TransposeH\x00R\ttranspose\x12w\n unresolved_table_valued_function\x18+ \x01(\x0b\x32,.spark.connect.UnresolvedTableValuedFunctionH\x00R\x1dunresolvedTableValuedFunction\x12?\n\x0clateral_join\x18, \x01(\x0b\x32\x1a.spark.connect.LateralJoinH\x00R\x0blateralJoin\x12\x30\n\x07\x66ill_na\x18Z \x01(\x0b\x32\x15.spark.connect.NAFillH\x00R\x06\x66illNa\x12\x30\n\x07\x64rop_na\x18[ \x01(\x0b\x32\x15.spark.connect.NADropH\x00R\x06\x64ropNa\x12\x34\n\x07replace\x18\\ \x01(\x0b\x32\x18.spark.connect.NAReplaceH\x00R\x07replace\x12\x36\n\x07summary\x18\x64 \x01(\x0b\x32\x1a.spark.connect.StatSummaryH\x00R\x07summary\x12\x39\n\x08\x63rosstab\x18\x65 \x01(\x0b\x32\x1b.spark.connect.StatCrosstabH\x00R\x08\x63rosstab\x12\x39\n\x08\x64\x65scribe\x18\x66 \x01(\x0b\x32\x1b.spark.connect.StatDescribeH\x00R\x08\x64\x65scribe\x12*\n\x03\x63ov\x18g \x01(\x0b\x32\x16.spark.connect.StatCovH\x00R\x03\x63ov\x12-\n\x04\x63orr\x18h \x01(\x0b\x32\x17.spark.connect.StatCorrH\x00R\x04\x63orr\x12L\n\x0f\x61pprox_quantile\x18i \x01(\x0b\x32!.spark.connect.StatApproxQuantileH\x00R\x0e\x61pproxQuantile\x12=\n\nfreq_items\x18j \x01(\x0b\x32\x1c.spark.connect.StatFreqItemsH\x00R\tfreqItems\x12:\n\tsample_by\x18k \x01(\x0b\x32\x1b.spark.connect.StatSampleByH\x00R\x08sampleBy\x12\x33\n\x07\x63\x61talog\x18\xc8\x01 \x01(\x0b\x32\x16.spark.connect.CatalogH\x00R\x07\x63\x61talog\x12=\n\x0bml_relation\x18\xac\x02 \x01(\x0b\x32\x19.spark.connect.MlRelationH\x00R\nmlRelation\x12\x35\n\textension\x18\xe6\x07 \x01(\x0b\x32\x14.google.protobuf.AnyH\x00R\textension\x12\x33\n\x07unknown\x18\xe7\x07 \x01(\x0b\x32\x16.spark.connect.UnknownH\x00R\x07unknownB\n\n\x08rel_type"\xf8\x02\n\nMlRelation\x12\x43\n\ttransform\x18\x01 \x01(\x0b\x32#.spark.connect.MlRelation.TransformH\x00R\ttransform\x12,\n\x05\x66\x65tch\x18\x02 \x01(\x0b\x32\x14.spark.connect.FetchH\x00R\x05\x66\x65tch\x1a\xeb\x01\n\tTransform\x12\x33\n\x07obj_ref\x18\x01 \x01(\x0b\x32\x18.spark.connect.ObjectRefH\x00R\x06objRef\x12=\n\x0btransformer\x18\x02 \x01(\x0b\x32\x19.spark.connect.MlOperatorH\x00R\x0btransformer\x12-\n\x05input\x18\x03 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12/\n\x06params\x18\x04 \x01(\x0b\x32\x17.spark.connect.MlParamsR\x06paramsB\n\n\x08operatorB\t\n\x07ml_type"\xcb\x02\n\x05\x46\x65tch\x12\x31\n\x07obj_ref\x18\x01 \x01(\x0b\x32\x18.spark.connect.ObjectRefR\x06objRef\x12\x35\n\x07methods\x18\x02 
\x03(\x0b\x32\x1b.spark.connect.Fetch.MethodR\x07methods\x1a\xd7\x01\n\x06Method\x12\x16\n\x06method\x18\x01 \x01(\tR\x06method\x12\x34\n\x04\x61rgs\x18\x02 \x03(\x0b\x32 .spark.connect.Fetch.Method.ArgsR\x04\x61rgs\x1a\x7f\n\x04\x41rgs\x12\x39\n\x05param\x18\x01 \x01(\x0b\x32!.spark.connect.Expression.LiteralH\x00R\x05param\x12/\n\x05input\x18\x02 \x01(\x0b\x32\x17.spark.connect.RelationH\x00R\x05inputB\x0b\n\targs_type"\t\n\x07Unknown"\x8e\x01\n\x0eRelationCommon\x12#\n\x0bsource_info\x18\x01 \x01(\tB\x02\x18\x01R\nsourceInfo\x12\x1c\n\x07plan_id\x18\x02 \x01(\x03H\x00R\x06planId\x88\x01\x01\x12-\n\x06origin\x18\x03 \x01(\x0b\x32\x15.spark.connect.OriginR\x06originB\n\n\x08_plan_id"\xde\x03\n\x03SQL\x12\x14\n\x05query\x18\x01 \x01(\tR\x05query\x12\x34\n\x04\x61rgs\x18\x02 \x03(\x0b\x32\x1c.spark.connect.SQL.ArgsEntryB\x02\x18\x01R\x04\x61rgs\x12@\n\x08pos_args\x18\x03 \x03(\x0b\x32!.spark.connect.Expression.LiteralB\x02\x18\x01R\x07posArgs\x12O\n\x0fnamed_arguments\x18\x04 \x03(\x0b\x32&.spark.connect.SQL.NamedArgumentsEntryR\x0enamedArguments\x12>\n\rpos_arguments\x18\x05 \x03(\x0b\x32\x19.spark.connect.ExpressionR\x0cposArguments\x1aZ\n\tArgsEntry\x12\x10\n\x03key\x18\x01 \x01(\tR\x03key\x12\x37\n\x05value\x18\x02 \x01(\x0b\x32!.spark.connect.Expression.LiteralR\x05value:\x02\x38\x01\x1a\\\n\x13NamedArgumentsEntry\x12\x10\n\x03key\x18\x01 \x01(\tR\x03key\x12/\n\x05value\x18\x02 \x01(\x0b\x32\x19.spark.connect.ExpressionR\x05value:\x02\x38\x01"u\n\rWithRelations\x12+\n\x04root\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x04root\x12\x37\n\nreferences\x18\x02 \x03(\x0b\x32\x17.spark.connect.RelationR\nreferences"\x97\x05\n\x04Read\x12\x41\n\x0bnamed_table\x18\x01 \x01(\x0b\x32\x1e.spark.connect.Read.NamedTableH\x00R\nnamedTable\x12\x41\n\x0b\x64\x61ta_source\x18\x02 \x01(\x0b\x32\x1e.spark.connect.Read.DataSourceH\x00R\ndataSource\x12!\n\x0cis_streaming\x18\x03 \x01(\x08R\x0bisStreaming\x1a\xc0\x01\n\nNamedTable\x12/\n\x13unparsed_identifier\x18\x01 \x01(\tR\x12unparsedIdentifier\x12\x45\n\x07options\x18\x02 \x03(\x0b\x32+.spark.connect.Read.NamedTable.OptionsEntryR\x07options\x1a:\n\x0cOptionsEntry\x12\x10\n\x03key\x18\x01 \x01(\tR\x03key\x12\x14\n\x05value\x18\x02 \x01(\tR\x05value:\x02\x38\x01\x1a\x95\x02\n\nDataSource\x12\x1b\n\x06\x66ormat\x18\x01 \x01(\tH\x00R\x06\x66ormat\x88\x01\x01\x12\x1b\n\x06schema\x18\x02 \x01(\tH\x01R\x06schema\x88\x01\x01\x12\x45\n\x07options\x18\x03 \x03(\x0b\x32+.spark.connect.Read.DataSource.OptionsEntryR\x07options\x12\x14\n\x05paths\x18\x04 \x03(\tR\x05paths\x12\x1e\n\npredicates\x18\x05 \x03(\tR\npredicates\x1a:\n\x0cOptionsEntry\x12\x10\n\x03key\x18\x01 \x01(\tR\x03key\x12\x14\n\x05value\x18\x02 \x01(\tR\x05value:\x02\x38\x01\x42\t\n\x07_formatB\t\n\x07_schemaB\x0b\n\tread_type"u\n\x07Project\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12;\n\x0b\x65xpressions\x18\x03 \x03(\x0b\x32\x19.spark.connect.ExpressionR\x0b\x65xpressions"p\n\x06\x46ilter\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12\x37\n\tcondition\x18\x02 \x01(\x0b\x32\x19.spark.connect.ExpressionR\tcondition"\x95\x05\n\x04Join\x12+\n\x04left\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x04left\x12-\n\x05right\x18\x02 \x01(\x0b\x32\x17.spark.connect.RelationR\x05right\x12@\n\x0ejoin_condition\x18\x03 \x01(\x0b\x32\x19.spark.connect.ExpressionR\rjoinCondition\x12\x39\n\tjoin_type\x18\x04 \x01(\x0e\x32\x1c.spark.connect.Join.JoinTypeR\x08joinType\x12#\n\rusing_columns\x18\x05 
\x03(\tR\x0cusingColumns\x12K\n\x0ejoin_data_type\x18\x06 \x01(\x0b\x32 .spark.connect.Join.JoinDataTypeH\x00R\x0cjoinDataType\x88\x01\x01\x1a\\\n\x0cJoinDataType\x12$\n\x0eis_left_struct\x18\x01 \x01(\x08R\x0cisLeftStruct\x12&\n\x0fis_right_struct\x18\x02 \x01(\x08R\risRightStruct"\xd0\x01\n\x08JoinType\x12\x19\n\x15JOIN_TYPE_UNSPECIFIED\x10\x00\x12\x13\n\x0fJOIN_TYPE_INNER\x10\x01\x12\x18\n\x14JOIN_TYPE_FULL_OUTER\x10\x02\x12\x18\n\x14JOIN_TYPE_LEFT_OUTER\x10\x03\x12\x19\n\x15JOIN_TYPE_RIGHT_OUTER\x10\x04\x12\x17\n\x13JOIN_TYPE_LEFT_ANTI\x10\x05\x12\x17\n\x13JOIN_TYPE_LEFT_SEMI\x10\x06\x12\x13\n\x0fJOIN_TYPE_CROSS\x10\x07\x42\x11\n\x0f_join_data_type"\xdf\x03\n\x0cSetOperation\x12\x36\n\nleft_input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\tleftInput\x12\x38\n\x0bright_input\x18\x02 \x01(\x0b\x32\x17.spark.connect.RelationR\nrightInput\x12\x45\n\x0bset_op_type\x18\x03 \x01(\x0e\x32%.spark.connect.SetOperation.SetOpTypeR\tsetOpType\x12\x1a\n\x06is_all\x18\x04 \x01(\x08H\x00R\x05isAll\x88\x01\x01\x12\x1c\n\x07\x62y_name\x18\x05 \x01(\x08H\x01R\x06\x62yName\x88\x01\x01\x12\x37\n\x15\x61llow_missing_columns\x18\x06 \x01(\x08H\x02R\x13\x61llowMissingColumns\x88\x01\x01"r\n\tSetOpType\x12\x1b\n\x17SET_OP_TYPE_UNSPECIFIED\x10\x00\x12\x19\n\x15SET_OP_TYPE_INTERSECT\x10\x01\x12\x15\n\x11SET_OP_TYPE_UNION\x10\x02\x12\x16\n\x12SET_OP_TYPE_EXCEPT\x10\x03\x42\t\n\x07_is_allB\n\n\x08_by_nameB\x18\n\x16_allow_missing_columns"L\n\x05Limit\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12\x14\n\x05limit\x18\x02 \x01(\x05R\x05limit"O\n\x06Offset\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12\x16\n\x06offset\x18\x02 \x01(\x05R\x06offset"K\n\x04Tail\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12\x14\n\x05limit\x18\x02 \x01(\x05R\x05limit"\xfe\x05\n\tAggregate\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12\x41\n\ngroup_type\x18\x02 \x01(\x0e\x32".spark.connect.Aggregate.GroupTypeR\tgroupType\x12L\n\x14grouping_expressions\x18\x03 \x03(\x0b\x32\x19.spark.connect.ExpressionR\x13groupingExpressions\x12N\n\x15\x61ggregate_expressions\x18\x04 \x03(\x0b\x32\x19.spark.connect.ExpressionR\x14\x61ggregateExpressions\x12\x34\n\x05pivot\x18\x05 \x01(\x0b\x32\x1e.spark.connect.Aggregate.PivotR\x05pivot\x12J\n\rgrouping_sets\x18\x06 \x03(\x0b\x32%.spark.connect.Aggregate.GroupingSetsR\x0cgroupingSets\x1ao\n\x05Pivot\x12+\n\x03\x63ol\x18\x01 \x01(\x0b\x32\x19.spark.connect.ExpressionR\x03\x63ol\x12\x39\n\x06values\x18\x02 \x03(\x0b\x32!.spark.connect.Expression.LiteralR\x06values\x1aL\n\x0cGroupingSets\x12<\n\x0cgrouping_set\x18\x01 \x03(\x0b\x32\x19.spark.connect.ExpressionR\x0bgroupingSet"\x9f\x01\n\tGroupType\x12\x1a\n\x16GROUP_TYPE_UNSPECIFIED\x10\x00\x12\x16\n\x12GROUP_TYPE_GROUPBY\x10\x01\x12\x15\n\x11GROUP_TYPE_ROLLUP\x10\x02\x12\x13\n\x0fGROUP_TYPE_CUBE\x10\x03\x12\x14\n\x10GROUP_TYPE_PIVOT\x10\x04\x12\x1c\n\x18GROUP_TYPE_GROUPING_SETS\x10\x05"\xa0\x01\n\x04Sort\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12\x39\n\x05order\x18\x02 \x03(\x0b\x32#.spark.connect.Expression.SortOrderR\x05order\x12 \n\tis_global\x18\x03 \x01(\x08H\x00R\x08isGlobal\x88\x01\x01\x42\x0c\n\n_is_global"\x8d\x01\n\x04\x44rop\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12\x33\n\x07\x63olumns\x18\x02 \x03(\x0b\x32\x19.spark.connect.ExpressionR\x07\x63olumns\x12!\n\x0c\x63olumn_names\x18\x03 
\x03(\tR\x0b\x63olumnNames"\xf0\x01\n\x0b\x44\x65\x64uplicate\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12!\n\x0c\x63olumn_names\x18\x02 \x03(\tR\x0b\x63olumnNames\x12\x32\n\x13\x61ll_columns_as_keys\x18\x03 \x01(\x08H\x00R\x10\x61llColumnsAsKeys\x88\x01\x01\x12.\n\x10within_watermark\x18\x04 \x01(\x08H\x01R\x0fwithinWatermark\x88\x01\x01\x42\x16\n\x14_all_columns_as_keysB\x13\n\x11_within_watermark"Y\n\rLocalRelation\x12\x17\n\x04\x64\x61ta\x18\x01 \x01(\x0cH\x00R\x04\x64\x61ta\x88\x01\x01\x12\x1b\n\x06schema\x18\x02 \x01(\tH\x01R\x06schema\x88\x01\x01\x42\x07\n\x05_dataB\t\n\x07_schema"H\n\x13\x43\x61\x63hedLocalRelation\x12\x12\n\x04hash\x18\x03 \x01(\tR\x04hashJ\x04\x08\x01\x10\x02J\x04\x08\x02\x10\x03R\x06userIdR\tsessionId"7\n\x14\x43\x61\x63hedRemoteRelation\x12\x1f\n\x0brelation_id\x18\x01 \x01(\tR\nrelationId"\x91\x02\n\x06Sample\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12\x1f\n\x0blower_bound\x18\x02 \x01(\x01R\nlowerBound\x12\x1f\n\x0bupper_bound\x18\x03 \x01(\x01R\nupperBound\x12.\n\x10with_replacement\x18\x04 \x01(\x08H\x00R\x0fwithReplacement\x88\x01\x01\x12\x17\n\x04seed\x18\x05 \x01(\x03H\x01R\x04seed\x88\x01\x01\x12/\n\x13\x64\x65terministic_order\x18\x06 \x01(\x08R\x12\x64\x65terministicOrderB\x13\n\x11_with_replacementB\x07\n\x05_seed"\x91\x01\n\x05Range\x12\x19\n\x05start\x18\x01 \x01(\x03H\x00R\x05start\x88\x01\x01\x12\x10\n\x03\x65nd\x18\x02 \x01(\x03R\x03\x65nd\x12\x12\n\x04step\x18\x03 \x01(\x03R\x04step\x12*\n\x0enum_partitions\x18\x04 \x01(\x05H\x01R\rnumPartitions\x88\x01\x01\x42\x08\n\x06_startB\x11\n\x0f_num_partitions"r\n\rSubqueryAlias\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12\x14\n\x05\x61lias\x18\x02 \x01(\tR\x05\x61lias\x12\x1c\n\tqualifier\x18\x03 \x03(\tR\tqualifier"\x8e\x01\n\x0bRepartition\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12%\n\x0enum_partitions\x18\x02 \x01(\x05R\rnumPartitions\x12\x1d\n\x07shuffle\x18\x03 \x01(\x08H\x00R\x07shuffle\x88\x01\x01\x42\n\n\x08_shuffle"\x8e\x01\n\nShowString\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12\x19\n\x08num_rows\x18\x02 \x01(\x05R\x07numRows\x12\x1a\n\x08truncate\x18\x03 \x01(\x05R\x08truncate\x12\x1a\n\x08vertical\x18\x04 \x01(\x08R\x08vertical"r\n\nHtmlString\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12\x19\n\x08num_rows\x18\x02 \x01(\x05R\x07numRows\x12\x1a\n\x08truncate\x18\x03 \x01(\x05R\x08truncate"\\\n\x0bStatSummary\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12\x1e\n\nstatistics\x18\x02 \x03(\tR\nstatistics"Q\n\x0cStatDescribe\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12\x12\n\x04\x63ols\x18\x02 \x03(\tR\x04\x63ols"e\n\x0cStatCrosstab\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12\x12\n\x04\x63ol1\x18\x02 \x01(\tR\x04\x63ol1\x12\x12\n\x04\x63ol2\x18\x03 \x01(\tR\x04\x63ol2"`\n\x07StatCov\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12\x12\n\x04\x63ol1\x18\x02 \x01(\tR\x04\x63ol1\x12\x12\n\x04\x63ol2\x18\x03 \x01(\tR\x04\x63ol2"\x89\x01\n\x08StatCorr\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12\x12\n\x04\x63ol1\x18\x02 \x01(\tR\x04\x63ol1\x12\x12\n\x04\x63ol2\x18\x03 \x01(\tR\x04\x63ol2\x12\x1b\n\x06method\x18\x04 \x01(\tH\x00R\x06method\x88\x01\x01\x42\t\n\x07_method"\xa4\x01\n\x12StatApproxQuantile\x12-\n\x05input\x18\x01 
\x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12\x12\n\x04\x63ols\x18\x02 \x03(\tR\x04\x63ols\x12$\n\rprobabilities\x18\x03 \x03(\x01R\rprobabilities\x12%\n\x0erelative_error\x18\x04 \x01(\x01R\rrelativeError"}\n\rStatFreqItems\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12\x12\n\x04\x63ols\x18\x02 \x03(\tR\x04\x63ols\x12\x1d\n\x07support\x18\x03 \x01(\x01H\x00R\x07support\x88\x01\x01\x42\n\n\x08_support"\xb5\x02\n\x0cStatSampleBy\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12+\n\x03\x63ol\x18\x02 \x01(\x0b\x32\x19.spark.connect.ExpressionR\x03\x63ol\x12\x42\n\tfractions\x18\x03 \x03(\x0b\x32$.spark.connect.StatSampleBy.FractionR\tfractions\x12\x17\n\x04seed\x18\x05 \x01(\x03H\x00R\x04seed\x88\x01\x01\x1a\x63\n\x08\x46raction\x12;\n\x07stratum\x18\x01 \x01(\x0b\x32!.spark.connect.Expression.LiteralR\x07stratum\x12\x1a\n\x08\x66raction\x18\x02 \x01(\x01R\x08\x66ractionB\x07\n\x05_seed"\x86\x01\n\x06NAFill\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12\x12\n\x04\x63ols\x18\x02 \x03(\tR\x04\x63ols\x12\x39\n\x06values\x18\x03 \x03(\x0b\x32!.spark.connect.Expression.LiteralR\x06values"\x86\x01\n\x06NADrop\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12\x12\n\x04\x63ols\x18\x02 \x03(\tR\x04\x63ols\x12\'\n\rmin_non_nulls\x18\x03 \x01(\x05H\x00R\x0bminNonNulls\x88\x01\x01\x42\x10\n\x0e_min_non_nulls"\xa8\x02\n\tNAReplace\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12\x12\n\x04\x63ols\x18\x02 \x03(\tR\x04\x63ols\x12H\n\x0creplacements\x18\x03 \x03(\x0b\x32$.spark.connect.NAReplace.ReplacementR\x0creplacements\x1a\x8d\x01\n\x0bReplacement\x12>\n\told_value\x18\x01 \x01(\x0b\x32!.spark.connect.Expression.LiteralR\x08oldValue\x12>\n\tnew_value\x18\x02 \x01(\x0b\x32!.spark.connect.Expression.LiteralR\x08newValue"X\n\x04ToDF\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12!\n\x0c\x63olumn_names\x18\x02 \x03(\tR\x0b\x63olumnNames"\xfe\x02\n\x12WithColumnsRenamed\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12i\n\x12rename_columns_map\x18\x02 \x03(\x0b\x32\x37.spark.connect.WithColumnsRenamed.RenameColumnsMapEntryB\x02\x18\x01R\x10renameColumnsMap\x12\x42\n\x07renames\x18\x03 \x03(\x0b\x32(.spark.connect.WithColumnsRenamed.RenameR\x07renames\x1a\x43\n\x15RenameColumnsMapEntry\x12\x10\n\x03key\x18\x01 \x01(\tR\x03key\x12\x14\n\x05value\x18\x02 \x01(\tR\x05value:\x02\x38\x01\x1a\x45\n\x06Rename\x12\x19\n\x08\x63ol_name\x18\x01 \x01(\tR\x07\x63olName\x12 \n\x0cnew_col_name\x18\x02 \x01(\tR\nnewColName"w\n\x0bWithColumns\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12\x39\n\x07\x61liases\x18\x02 \x03(\x0b\x32\x1f.spark.connect.Expression.AliasR\x07\x61liases"\x86\x01\n\rWithWatermark\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12\x1d\n\nevent_time\x18\x02 \x01(\tR\teventTime\x12\'\n\x0f\x64\x65lay_threshold\x18\x03 \x01(\tR\x0e\x64\x65layThreshold"\x84\x01\n\x04Hint\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12\x12\n\x04name\x18\x02 \x01(\tR\x04name\x12\x39\n\nparameters\x18\x03 \x03(\x0b\x32\x19.spark.connect.ExpressionR\nparameters"\xc7\x02\n\x07Unpivot\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12+\n\x03ids\x18\x02 \x03(\x0b\x32\x19.spark.connect.ExpressionR\x03ids\x12:\n\x06values\x18\x03 
\x01(\x0b\x32\x1d.spark.connect.Unpivot.ValuesH\x00R\x06values\x88\x01\x01\x12\x30\n\x14variable_column_name\x18\x04 \x01(\tR\x12variableColumnName\x12*\n\x11value_column_name\x18\x05 \x01(\tR\x0fvalueColumnName\x1a;\n\x06Values\x12\x31\n\x06values\x18\x01 \x03(\x0b\x32\x19.spark.connect.ExpressionR\x06valuesB\t\n\x07_values"z\n\tTranspose\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12>\n\rindex_columns\x18\x02 \x03(\x0b\x32\x19.spark.connect.ExpressionR\x0cindexColumns"}\n\x1dUnresolvedTableValuedFunction\x12#\n\rfunction_name\x18\x01 \x01(\tR\x0c\x66unctionName\x12\x37\n\targuments\x18\x02 \x03(\x0b\x32\x19.spark.connect.ExpressionR\targuments"j\n\x08ToSchema\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12/\n\x06schema\x18\x02 \x01(\x0b\x32\x17.spark.connect.DataTypeR\x06schema"\xcb\x01\n\x17RepartitionByExpression\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12\x42\n\x0fpartition_exprs\x18\x02 \x03(\x0b\x32\x19.spark.connect.ExpressionR\x0epartitionExprs\x12*\n\x0enum_partitions\x18\x03 \x01(\x05H\x00R\rnumPartitions\x88\x01\x01\x42\x11\n\x0f_num_partitions"\xe8\x01\n\rMapPartitions\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12\x42\n\x04\x66unc\x18\x02 \x01(\x0b\x32..spark.connect.CommonInlineUserDefinedFunctionR\x04\x66unc\x12"\n\nis_barrier\x18\x03 \x01(\x08H\x00R\tisBarrier\x88\x01\x01\x12"\n\nprofile_id\x18\x04 \x01(\x05H\x01R\tprofileId\x88\x01\x01\x42\r\n\x0b_is_barrierB\r\n\x0b_profile_id"\xd2\x06\n\x08GroupMap\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12L\n\x14grouping_expressions\x18\x02 \x03(\x0b\x32\x19.spark.connect.ExpressionR\x13groupingExpressions\x12\x42\n\x04\x66unc\x18\x03 \x01(\x0b\x32..spark.connect.CommonInlineUserDefinedFunctionR\x04\x66unc\x12J\n\x13sorting_expressions\x18\x04 \x03(\x0b\x32\x19.spark.connect.ExpressionR\x12sortingExpressions\x12<\n\rinitial_input\x18\x05 \x01(\x0b\x32\x17.spark.connect.RelationR\x0cinitialInput\x12[\n\x1cinitial_grouping_expressions\x18\x06 \x03(\x0b\x32\x19.spark.connect.ExpressionR\x1ainitialGroupingExpressions\x12;\n\x18is_map_groups_with_state\x18\x07 \x01(\x08H\x00R\x14isMapGroupsWithState\x88\x01\x01\x12$\n\x0boutput_mode\x18\x08 \x01(\tH\x01R\noutputMode\x88\x01\x01\x12&\n\x0ctimeout_conf\x18\t \x01(\tH\x02R\x0btimeoutConf\x88\x01\x01\x12?\n\x0cstate_schema\x18\n \x01(\x0b\x32\x17.spark.connect.DataTypeH\x03R\x0bstateSchema\x88\x01\x01\x12\x65\n\x19transform_with_state_info\x18\x0b \x01(\x0b\x32%.spark.connect.TransformWithStateInfoH\x04R\x16transformWithStateInfo\x88\x01\x01\x42\x1b\n\x19_is_map_groups_with_stateB\x0e\n\x0c_output_modeB\x0f\n\r_timeout_confB\x0f\n\r_state_schemaB\x1c\n\x1a_transform_with_state_info"\xdf\x01\n\x16TransformWithStateInfo\x12\x1b\n\ttime_mode\x18\x01 \x01(\tR\x08timeMode\x12\x38\n\x16\x65vent_time_column_name\x18\x02 \x01(\tH\x00R\x13\x65ventTimeColumnName\x88\x01\x01\x12\x41\n\routput_schema\x18\x03 \x01(\x0b\x32\x17.spark.connect.DataTypeH\x01R\x0coutputSchema\x88\x01\x01\x42\x19\n\x17_event_time_column_nameB\x10\n\x0e_output_schema"\x8e\x04\n\nCoGroupMap\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12W\n\x1ainput_grouping_expressions\x18\x02 \x03(\x0b\x32\x19.spark.connect.ExpressionR\x18inputGroupingExpressions\x12-\n\x05other\x18\x03 \x01(\x0b\x32\x17.spark.connect.RelationR\x05other\x12W\n\x1aother_grouping_expressions\x18\x04 
\x03(\x0b\x32\x19.spark.connect.ExpressionR\x18otherGroupingExpressions\x12\x42\n\x04\x66unc\x18\x05 \x01(\x0b\x32..spark.connect.CommonInlineUserDefinedFunctionR\x04\x66unc\x12U\n\x19input_sorting_expressions\x18\x06 \x03(\x0b\x32\x19.spark.connect.ExpressionR\x17inputSortingExpressions\x12U\n\x19other_sorting_expressions\x18\x07 \x03(\x0b\x32\x19.spark.connect.ExpressionR\x17otherSortingExpressions"\xe5\x02\n\x16\x41pplyInPandasWithState\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12L\n\x14grouping_expressions\x18\x02 \x03(\x0b\x32\x19.spark.connect.ExpressionR\x13groupingExpressions\x12\x42\n\x04\x66unc\x18\x03 \x01(\x0b\x32..spark.connect.CommonInlineUserDefinedFunctionR\x04\x66unc\x12#\n\routput_schema\x18\x04 \x01(\tR\x0coutputSchema\x12!\n\x0cstate_schema\x18\x05 \x01(\tR\x0bstateSchema\x12\x1f\n\x0boutput_mode\x18\x06 \x01(\tR\noutputMode\x12!\n\x0ctimeout_conf\x18\x07 \x01(\tR\x0btimeoutConf"\xf4\x01\n$CommonInlineUserDefinedTableFunction\x12#\n\rfunction_name\x18\x01 \x01(\tR\x0c\x66unctionName\x12$\n\rdeterministic\x18\x02 \x01(\x08R\rdeterministic\x12\x37\n\targuments\x18\x03 \x03(\x0b\x32\x19.spark.connect.ExpressionR\targuments\x12<\n\x0bpython_udtf\x18\x04 \x01(\x0b\x32\x19.spark.connect.PythonUDTFH\x00R\npythonUdtfB\n\n\x08\x66unction"\xb1\x01\n\nPythonUDTF\x12=\n\x0breturn_type\x18\x01 \x01(\x0b\x32\x17.spark.connect.DataTypeH\x00R\nreturnType\x88\x01\x01\x12\x1b\n\teval_type\x18\x02 \x01(\x05R\x08\x65valType\x12\x18\n\x07\x63ommand\x18\x03 \x01(\x0cR\x07\x63ommand\x12\x1d\n\npython_ver\x18\x04 \x01(\tR\tpythonVerB\x0e\n\x0c_return_type"\x97\x01\n!CommonInlineUserDefinedDataSource\x12\x12\n\x04name\x18\x01 \x01(\tR\x04name\x12O\n\x12python_data_source\x18\x02 \x01(\x0b\x32\x1f.spark.connect.PythonDataSourceH\x00R\x10pythonDataSourceB\r\n\x0b\x64\x61ta_source"K\n\x10PythonDataSource\x12\x18\n\x07\x63ommand\x18\x01 \x01(\x0cR\x07\x63ommand\x12\x1d\n\npython_ver\x18\x02 \x01(\tR\tpythonVer"\x88\x01\n\x0e\x43ollectMetrics\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12\x12\n\x04name\x18\x02 \x01(\tR\x04name\x12\x33\n\x07metrics\x18\x03 \x03(\x0b\x32\x19.spark.connect.ExpressionR\x07metrics"\x84\x03\n\x05Parse\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12\x38\n\x06\x66ormat\x18\x02 \x01(\x0e\x32 .spark.connect.Parse.ParseFormatR\x06\x66ormat\x12\x34\n\x06schema\x18\x03 \x01(\x0b\x32\x17.spark.connect.DataTypeH\x00R\x06schema\x88\x01\x01\x12;\n\x07options\x18\x04 \x03(\x0b\x32!.spark.connect.Parse.OptionsEntryR\x07options\x1a:\n\x0cOptionsEntry\x12\x10\n\x03key\x18\x01 \x01(\tR\x03key\x12\x14\n\x05value\x18\x02 \x01(\tR\x05value:\x02\x38\x01"X\n\x0bParseFormat\x12\x1c\n\x18PARSE_FORMAT_UNSPECIFIED\x10\x00\x12\x14\n\x10PARSE_FORMAT_CSV\x10\x01\x12\x15\n\x11PARSE_FORMAT_JSON\x10\x02\x42\t\n\x07_schema"\xdb\x03\n\x08\x41sOfJoin\x12+\n\x04left\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x04left\x12-\n\x05right\x18\x02 \x01(\x0b\x32\x17.spark.connect.RelationR\x05right\x12\x37\n\nleft_as_of\x18\x03 \x01(\x0b\x32\x19.spark.connect.ExpressionR\x08leftAsOf\x12\x39\n\x0bright_as_of\x18\x04 \x01(\x0b\x32\x19.spark.connect.ExpressionR\trightAsOf\x12\x36\n\tjoin_expr\x18\x05 \x01(\x0b\x32\x19.spark.connect.ExpressionR\x08joinExpr\x12#\n\rusing_columns\x18\x06 \x03(\tR\x0cusingColumns\x12\x1b\n\tjoin_type\x18\x07 \x01(\tR\x08joinType\x12\x37\n\ttolerance\x18\x08 \x01(\x0b\x32\x19.spark.connect.ExpressionR\ttolerance\x12.\n\x13\x61llow_exact_matches\x18\t 
\x01(\x08R\x11\x61llowExactMatches\x12\x1c\n\tdirection\x18\n \x01(\tR\tdirection"\xe6\x01\n\x0bLateralJoin\x12+\n\x04left\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x04left\x12-\n\x05right\x18\x02 \x01(\x0b\x32\x17.spark.connect.RelationR\x05right\x12@\n\x0ejoin_condition\x18\x03 \x01(\x0b\x32\x19.spark.connect.ExpressionR\rjoinCondition\x12\x39\n\tjoin_type\x18\x04 \x01(\x0e\x32\x1c.spark.connect.Join.JoinTypeR\x08joinTypeB6\n\x1eorg.apache.spark.connect.protoP\x01Z\x12internal/generatedb\x06proto3' + b'\n\x1dspark/connect/relations.proto\x12\rspark.connect\x1a\x19google/protobuf/any.proto\x1a\x1fspark/connect/expressions.proto\x1a\x19spark/connect/types.proto\x1a\x1bspark/connect/catalog.proto\x1a\x1aspark/connect/common.proto\x1a\x1dspark/connect/ml_common.proto"\x9c\x1d\n\x08Relation\x12\x35\n\x06\x63ommon\x18\x01 \x01(\x0b\x32\x1d.spark.connect.RelationCommonR\x06\x63ommon\x12)\n\x04read\x18\x02 \x01(\x0b\x32\x13.spark.connect.ReadH\x00R\x04read\x12\x32\n\x07project\x18\x03 \x01(\x0b\x32\x16.spark.connect.ProjectH\x00R\x07project\x12/\n\x06\x66ilter\x18\x04 \x01(\x0b\x32\x15.spark.connect.FilterH\x00R\x06\x66ilter\x12)\n\x04join\x18\x05 \x01(\x0b\x32\x13.spark.connect.JoinH\x00R\x04join\x12\x34\n\x06set_op\x18\x06 \x01(\x0b\x32\x1b.spark.connect.SetOperationH\x00R\x05setOp\x12)\n\x04sort\x18\x07 \x01(\x0b\x32\x13.spark.connect.SortH\x00R\x04sort\x12,\n\x05limit\x18\x08 \x01(\x0b\x32\x14.spark.connect.LimitH\x00R\x05limit\x12\x38\n\taggregate\x18\t \x01(\x0b\x32\x18.spark.connect.AggregateH\x00R\taggregate\x12&\n\x03sql\x18\n \x01(\x0b\x32\x12.spark.connect.SQLH\x00R\x03sql\x12\x45\n\x0elocal_relation\x18\x0b \x01(\x0b\x32\x1c.spark.connect.LocalRelationH\x00R\rlocalRelation\x12/\n\x06sample\x18\x0c \x01(\x0b\x32\x15.spark.connect.SampleH\x00R\x06sample\x12/\n\x06offset\x18\r \x01(\x0b\x32\x15.spark.connect.OffsetH\x00R\x06offset\x12>\n\x0b\x64\x65\x64uplicate\x18\x0e \x01(\x0b\x32\x1a.spark.connect.DeduplicateH\x00R\x0b\x64\x65\x64uplicate\x12,\n\x05range\x18\x0f \x01(\x0b\x32\x14.spark.connect.RangeH\x00R\x05range\x12\x45\n\x0esubquery_alias\x18\x10 \x01(\x0b\x32\x1c.spark.connect.SubqueryAliasH\x00R\rsubqueryAlias\x12>\n\x0brepartition\x18\x11 \x01(\x0b\x32\x1a.spark.connect.RepartitionH\x00R\x0brepartition\x12*\n\x05to_df\x18\x12 \x01(\x0b\x32\x13.spark.connect.ToDFH\x00R\x04toDf\x12U\n\x14with_columns_renamed\x18\x13 \x01(\x0b\x32!.spark.connect.WithColumnsRenamedH\x00R\x12withColumnsRenamed\x12<\n\x0bshow_string\x18\x14 \x01(\x0b\x32\x19.spark.connect.ShowStringH\x00R\nshowString\x12)\n\x04\x64rop\x18\x15 \x01(\x0b\x32\x13.spark.connect.DropH\x00R\x04\x64rop\x12)\n\x04tail\x18\x16 \x01(\x0b\x32\x13.spark.connect.TailH\x00R\x04tail\x12?\n\x0cwith_columns\x18\x17 \x01(\x0b\x32\x1a.spark.connect.WithColumnsH\x00R\x0bwithColumns\x12)\n\x04hint\x18\x18 \x01(\x0b\x32\x13.spark.connect.HintH\x00R\x04hint\x12\x32\n\x07unpivot\x18\x19 \x01(\x0b\x32\x16.spark.connect.UnpivotH\x00R\x07unpivot\x12\x36\n\tto_schema\x18\x1a \x01(\x0b\x32\x17.spark.connect.ToSchemaH\x00R\x08toSchema\x12\x64\n\x19repartition_by_expression\x18\x1b \x01(\x0b\x32&.spark.connect.RepartitionByExpressionH\x00R\x17repartitionByExpression\x12\x45\n\x0emap_partitions\x18\x1c \x01(\x0b\x32\x1c.spark.connect.MapPartitionsH\x00R\rmapPartitions\x12H\n\x0f\x63ollect_metrics\x18\x1d \x01(\x0b\x32\x1d.spark.connect.CollectMetricsH\x00R\x0e\x63ollectMetrics\x12,\n\x05parse\x18\x1e \x01(\x0b\x32\x14.spark.connect.ParseH\x00R\x05parse\x12\x36\n\tgroup_map\x18\x1f 
\x01(\x0b\x32\x17.spark.connect.GroupMapH\x00R\x08groupMap\x12=\n\x0c\x63o_group_map\x18 \x01(\x0b\x32\x19.spark.connect.CoGroupMapH\x00R\ncoGroupMap\x12\x45\n\x0ewith_watermark\x18! \x01(\x0b\x32\x1c.spark.connect.WithWatermarkH\x00R\rwithWatermark\x12\x63\n\x1a\x61pply_in_pandas_with_state\x18" \x01(\x0b\x32%.spark.connect.ApplyInPandasWithStateH\x00R\x16\x61pplyInPandasWithState\x12<\n\x0bhtml_string\x18# \x01(\x0b\x32\x19.spark.connect.HtmlStringH\x00R\nhtmlString\x12X\n\x15\x63\x61\x63hed_local_relation\x18$ \x01(\x0b\x32".spark.connect.CachedLocalRelationH\x00R\x13\x63\x61\x63hedLocalRelation\x12[\n\x16\x63\x61\x63hed_remote_relation\x18% \x01(\x0b\x32#.spark.connect.CachedRemoteRelationH\x00R\x14\x63\x61\x63hedRemoteRelation\x12\x8e\x01\n)common_inline_user_defined_table_function\x18& \x01(\x0b\x32\x33.spark.connect.CommonInlineUserDefinedTableFunctionH\x00R$commonInlineUserDefinedTableFunction\x12\x37\n\nas_of_join\x18\' \x01(\x0b\x32\x17.spark.connect.AsOfJoinH\x00R\x08\x61sOfJoin\x12\x85\x01\n&common_inline_user_defined_data_source\x18( \x01(\x0b\x32\x30.spark.connect.CommonInlineUserDefinedDataSourceH\x00R!commonInlineUserDefinedDataSource\x12\x45\n\x0ewith_relations\x18) \x01(\x0b\x32\x1c.spark.connect.WithRelationsH\x00R\rwithRelations\x12\x38\n\ttranspose\x18* \x01(\x0b\x32\x18.spark.connect.TransposeH\x00R\ttranspose\x12w\n unresolved_table_valued_function\x18+ \x01(\x0b\x32,.spark.connect.UnresolvedTableValuedFunctionH\x00R\x1dunresolvedTableValuedFunction\x12?\n\x0clateral_join\x18, \x01(\x0b\x32\x1a.spark.connect.LateralJoinH\x00R\x0blateralJoin\x12\x30\n\x07\x66ill_na\x18Z \x01(\x0b\x32\x15.spark.connect.NAFillH\x00R\x06\x66illNa\x12\x30\n\x07\x64rop_na\x18[ \x01(\x0b\x32\x15.spark.connect.NADropH\x00R\x06\x64ropNa\x12\x34\n\x07replace\x18\\ \x01(\x0b\x32\x18.spark.connect.NAReplaceH\x00R\x07replace\x12\x36\n\x07summary\x18\x64 \x01(\x0b\x32\x1a.spark.connect.StatSummaryH\x00R\x07summary\x12\x39\n\x08\x63rosstab\x18\x65 \x01(\x0b\x32\x1b.spark.connect.StatCrosstabH\x00R\x08\x63rosstab\x12\x39\n\x08\x64\x65scribe\x18\x66 \x01(\x0b\x32\x1b.spark.connect.StatDescribeH\x00R\x08\x64\x65scribe\x12*\n\x03\x63ov\x18g \x01(\x0b\x32\x16.spark.connect.StatCovH\x00R\x03\x63ov\x12-\n\x04\x63orr\x18h \x01(\x0b\x32\x17.spark.connect.StatCorrH\x00R\x04\x63orr\x12L\n\x0f\x61pprox_quantile\x18i \x01(\x0b\x32!.spark.connect.StatApproxQuantileH\x00R\x0e\x61pproxQuantile\x12=\n\nfreq_items\x18j \x01(\x0b\x32\x1c.spark.connect.StatFreqItemsH\x00R\tfreqItems\x12:\n\tsample_by\x18k \x01(\x0b\x32\x1b.spark.connect.StatSampleByH\x00R\x08sampleBy\x12\x33\n\x07\x63\x61talog\x18\xc8\x01 \x01(\x0b\x32\x16.spark.connect.CatalogH\x00R\x07\x63\x61talog\x12=\n\x0bml_relation\x18\xac\x02 \x01(\x0b\x32\x19.spark.connect.MlRelationH\x00R\nmlRelation\x12\x35\n\textension\x18\xe6\x07 \x01(\x0b\x32\x14.google.protobuf.AnyH\x00R\textension\x12\x33\n\x07unknown\x18\xe7\x07 \x01(\x0b\x32\x16.spark.connect.UnknownH\x00R\x07unknownB\n\n\x08rel_type"\xe4\x03\n\nMlRelation\x12\x43\n\ttransform\x18\x01 \x01(\x0b\x32#.spark.connect.MlRelation.TransformH\x00R\ttransform\x12,\n\x05\x66\x65tch\x18\x02 \x01(\x0b\x32\x14.spark.connect.FetchH\x00R\x05\x66\x65tch\x12P\n\x15model_summary_dataset\x18\x03 \x01(\x0b\x32\x17.spark.connect.RelationH\x01R\x13modelSummaryDataset\x88\x01\x01\x1a\xeb\x01\n\tTransform\x12\x33\n\x07obj_ref\x18\x01 \x01(\x0b\x32\x18.spark.connect.ObjectRefH\x00R\x06objRef\x12=\n\x0btransformer\x18\x02 \x01(\x0b\x32\x19.spark.connect.MlOperatorH\x00R\x0btransformer\x12-\n\x05input\x18\x03 
\x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12/\n\x06params\x18\x04 \x01(\x0b\x32\x17.spark.connect.MlParamsR\x06paramsB\n\n\x08operatorB\t\n\x07ml_typeB\x18\n\x16_model_summary_dataset"\xcb\x02\n\x05\x46\x65tch\x12\x31\n\x07obj_ref\x18\x01 \x01(\x0b\x32\x18.spark.connect.ObjectRefR\x06objRef\x12\x35\n\x07methods\x18\x02 \x03(\x0b\x32\x1b.spark.connect.Fetch.MethodR\x07methods\x1a\xd7\x01\n\x06Method\x12\x16\n\x06method\x18\x01 \x01(\tR\x06method\x12\x34\n\x04\x61rgs\x18\x02 \x03(\x0b\x32 .spark.connect.Fetch.Method.ArgsR\x04\x61rgs\x1a\x7f\n\x04\x41rgs\x12\x39\n\x05param\x18\x01 \x01(\x0b\x32!.spark.connect.Expression.LiteralH\x00R\x05param\x12/\n\x05input\x18\x02 \x01(\x0b\x32\x17.spark.connect.RelationH\x00R\x05inputB\x0b\n\targs_type"\t\n\x07Unknown"\x8e\x01\n\x0eRelationCommon\x12#\n\x0bsource_info\x18\x01 \x01(\tB\x02\x18\x01R\nsourceInfo\x12\x1c\n\x07plan_id\x18\x02 \x01(\x03H\x00R\x06planId\x88\x01\x01\x12-\n\x06origin\x18\x03 \x01(\x0b\x32\x15.spark.connect.OriginR\x06originB\n\n\x08_plan_id"\xde\x03\n\x03SQL\x12\x14\n\x05query\x18\x01 \x01(\tR\x05query\x12\x34\n\x04\x61rgs\x18\x02 \x03(\x0b\x32\x1c.spark.connect.SQL.ArgsEntryB\x02\x18\x01R\x04\x61rgs\x12@\n\x08pos_args\x18\x03 \x03(\x0b\x32!.spark.connect.Expression.LiteralB\x02\x18\x01R\x07posArgs\x12O\n\x0fnamed_arguments\x18\x04 \x03(\x0b\x32&.spark.connect.SQL.NamedArgumentsEntryR\x0enamedArguments\x12>\n\rpos_arguments\x18\x05 \x03(\x0b\x32\x19.spark.connect.ExpressionR\x0cposArguments\x1aZ\n\tArgsEntry\x12\x10\n\x03key\x18\x01 \x01(\tR\x03key\x12\x37\n\x05value\x18\x02 \x01(\x0b\x32!.spark.connect.Expression.LiteralR\x05value:\x02\x38\x01\x1a\\\n\x13NamedArgumentsEntry\x12\x10\n\x03key\x18\x01 \x01(\tR\x03key\x12/\n\x05value\x18\x02 \x01(\x0b\x32\x19.spark.connect.ExpressionR\x05value:\x02\x38\x01"u\n\rWithRelations\x12+\n\x04root\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x04root\x12\x37\n\nreferences\x18\x02 \x03(\x0b\x32\x17.spark.connect.RelationR\nreferences"\x97\x05\n\x04Read\x12\x41\n\x0bnamed_table\x18\x01 \x01(\x0b\x32\x1e.spark.connect.Read.NamedTableH\x00R\nnamedTable\x12\x41\n\x0b\x64\x61ta_source\x18\x02 \x01(\x0b\x32\x1e.spark.connect.Read.DataSourceH\x00R\ndataSource\x12!\n\x0cis_streaming\x18\x03 \x01(\x08R\x0bisStreaming\x1a\xc0\x01\n\nNamedTable\x12/\n\x13unparsed_identifier\x18\x01 \x01(\tR\x12unparsedIdentifier\x12\x45\n\x07options\x18\x02 \x03(\x0b\x32+.spark.connect.Read.NamedTable.OptionsEntryR\x07options\x1a:\n\x0cOptionsEntry\x12\x10\n\x03key\x18\x01 \x01(\tR\x03key\x12\x14\n\x05value\x18\x02 \x01(\tR\x05value:\x02\x38\x01\x1a\x95\x02\n\nDataSource\x12\x1b\n\x06\x66ormat\x18\x01 \x01(\tH\x00R\x06\x66ormat\x88\x01\x01\x12\x1b\n\x06schema\x18\x02 \x01(\tH\x01R\x06schema\x88\x01\x01\x12\x45\n\x07options\x18\x03 \x03(\x0b\x32+.spark.connect.Read.DataSource.OptionsEntryR\x07options\x12\x14\n\x05paths\x18\x04 \x03(\tR\x05paths\x12\x1e\n\npredicates\x18\x05 \x03(\tR\npredicates\x1a:\n\x0cOptionsEntry\x12\x10\n\x03key\x18\x01 \x01(\tR\x03key\x12\x14\n\x05value\x18\x02 \x01(\tR\x05value:\x02\x38\x01\x42\t\n\x07_formatB\t\n\x07_schemaB\x0b\n\tread_type"u\n\x07Project\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12;\n\x0b\x65xpressions\x18\x03 \x03(\x0b\x32\x19.spark.connect.ExpressionR\x0b\x65xpressions"p\n\x06\x46ilter\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12\x37\n\tcondition\x18\x02 \x01(\x0b\x32\x19.spark.connect.ExpressionR\tcondition"\x95\x05\n\x04Join\x12+\n\x04left\x18\x01 
\x01(\x0b\x32\x17.spark.connect.RelationR\x04left\x12-\n\x05right\x18\x02 \x01(\x0b\x32\x17.spark.connect.RelationR\x05right\x12@\n\x0ejoin_condition\x18\x03 \x01(\x0b\x32\x19.spark.connect.ExpressionR\rjoinCondition\x12\x39\n\tjoin_type\x18\x04 \x01(\x0e\x32\x1c.spark.connect.Join.JoinTypeR\x08joinType\x12#\n\rusing_columns\x18\x05 \x03(\tR\x0cusingColumns\x12K\n\x0ejoin_data_type\x18\x06 \x01(\x0b\x32 .spark.connect.Join.JoinDataTypeH\x00R\x0cjoinDataType\x88\x01\x01\x1a\\\n\x0cJoinDataType\x12$\n\x0eis_left_struct\x18\x01 \x01(\x08R\x0cisLeftStruct\x12&\n\x0fis_right_struct\x18\x02 \x01(\x08R\risRightStruct"\xd0\x01\n\x08JoinType\x12\x19\n\x15JOIN_TYPE_UNSPECIFIED\x10\x00\x12\x13\n\x0fJOIN_TYPE_INNER\x10\x01\x12\x18\n\x14JOIN_TYPE_FULL_OUTER\x10\x02\x12\x18\n\x14JOIN_TYPE_LEFT_OUTER\x10\x03\x12\x19\n\x15JOIN_TYPE_RIGHT_OUTER\x10\x04\x12\x17\n\x13JOIN_TYPE_LEFT_ANTI\x10\x05\x12\x17\n\x13JOIN_TYPE_LEFT_SEMI\x10\x06\x12\x13\n\x0fJOIN_TYPE_CROSS\x10\x07\x42\x11\n\x0f_join_data_type"\xdf\x03\n\x0cSetOperation\x12\x36\n\nleft_input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\tleftInput\x12\x38\n\x0bright_input\x18\x02 \x01(\x0b\x32\x17.spark.connect.RelationR\nrightInput\x12\x45\n\x0bset_op_type\x18\x03 \x01(\x0e\x32%.spark.connect.SetOperation.SetOpTypeR\tsetOpType\x12\x1a\n\x06is_all\x18\x04 \x01(\x08H\x00R\x05isAll\x88\x01\x01\x12\x1c\n\x07\x62y_name\x18\x05 \x01(\x08H\x01R\x06\x62yName\x88\x01\x01\x12\x37\n\x15\x61llow_missing_columns\x18\x06 \x01(\x08H\x02R\x13\x61llowMissingColumns\x88\x01\x01"r\n\tSetOpType\x12\x1b\n\x17SET_OP_TYPE_UNSPECIFIED\x10\x00\x12\x19\n\x15SET_OP_TYPE_INTERSECT\x10\x01\x12\x15\n\x11SET_OP_TYPE_UNION\x10\x02\x12\x16\n\x12SET_OP_TYPE_EXCEPT\x10\x03\x42\t\n\x07_is_allB\n\n\x08_by_nameB\x18\n\x16_allow_missing_columns"L\n\x05Limit\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12\x14\n\x05limit\x18\x02 \x01(\x05R\x05limit"O\n\x06Offset\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12\x16\n\x06offset\x18\x02 \x01(\x05R\x06offset"K\n\x04Tail\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12\x14\n\x05limit\x18\x02 \x01(\x05R\x05limit"\xfe\x05\n\tAggregate\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12\x41\n\ngroup_type\x18\x02 \x01(\x0e\x32".spark.connect.Aggregate.GroupTypeR\tgroupType\x12L\n\x14grouping_expressions\x18\x03 \x03(\x0b\x32\x19.spark.connect.ExpressionR\x13groupingExpressions\x12N\n\x15\x61ggregate_expressions\x18\x04 \x03(\x0b\x32\x19.spark.connect.ExpressionR\x14\x61ggregateExpressions\x12\x34\n\x05pivot\x18\x05 \x01(\x0b\x32\x1e.spark.connect.Aggregate.PivotR\x05pivot\x12J\n\rgrouping_sets\x18\x06 \x03(\x0b\x32%.spark.connect.Aggregate.GroupingSetsR\x0cgroupingSets\x1ao\n\x05Pivot\x12+\n\x03\x63ol\x18\x01 \x01(\x0b\x32\x19.spark.connect.ExpressionR\x03\x63ol\x12\x39\n\x06values\x18\x02 \x03(\x0b\x32!.spark.connect.Expression.LiteralR\x06values\x1aL\n\x0cGroupingSets\x12<\n\x0cgrouping_set\x18\x01 \x03(\x0b\x32\x19.spark.connect.ExpressionR\x0bgroupingSet"\x9f\x01\n\tGroupType\x12\x1a\n\x16GROUP_TYPE_UNSPECIFIED\x10\x00\x12\x16\n\x12GROUP_TYPE_GROUPBY\x10\x01\x12\x15\n\x11GROUP_TYPE_ROLLUP\x10\x02\x12\x13\n\x0fGROUP_TYPE_CUBE\x10\x03\x12\x14\n\x10GROUP_TYPE_PIVOT\x10\x04\x12\x1c\n\x18GROUP_TYPE_GROUPING_SETS\x10\x05"\xa0\x01\n\x04Sort\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12\x39\n\x05order\x18\x02 \x03(\x0b\x32#.spark.connect.Expression.SortOrderR\x05order\x12 \n\tis_global\x18\x03 
\x01(\x08H\x00R\x08isGlobal\x88\x01\x01\x42\x0c\n\n_is_global"\x8d\x01\n\x04\x44rop\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12\x33\n\x07\x63olumns\x18\x02 \x03(\x0b\x32\x19.spark.connect.ExpressionR\x07\x63olumns\x12!\n\x0c\x63olumn_names\x18\x03 \x03(\tR\x0b\x63olumnNames"\xf0\x01\n\x0b\x44\x65\x64uplicate\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12!\n\x0c\x63olumn_names\x18\x02 \x03(\tR\x0b\x63olumnNames\x12\x32\n\x13\x61ll_columns_as_keys\x18\x03 \x01(\x08H\x00R\x10\x61llColumnsAsKeys\x88\x01\x01\x12.\n\x10within_watermark\x18\x04 \x01(\x08H\x01R\x0fwithinWatermark\x88\x01\x01\x42\x16\n\x14_all_columns_as_keysB\x13\n\x11_within_watermark"Y\n\rLocalRelation\x12\x17\n\x04\x64\x61ta\x18\x01 \x01(\x0cH\x00R\x04\x64\x61ta\x88\x01\x01\x12\x1b\n\x06schema\x18\x02 \x01(\tH\x01R\x06schema\x88\x01\x01\x42\x07\n\x05_dataB\t\n\x07_schema"H\n\x13\x43\x61\x63hedLocalRelation\x12\x12\n\x04hash\x18\x03 \x01(\tR\x04hashJ\x04\x08\x01\x10\x02J\x04\x08\x02\x10\x03R\x06userIdR\tsessionId"7\n\x14\x43\x61\x63hedRemoteRelation\x12\x1f\n\x0brelation_id\x18\x01 \x01(\tR\nrelationId"\x91\x02\n\x06Sample\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12\x1f\n\x0blower_bound\x18\x02 \x01(\x01R\nlowerBound\x12\x1f\n\x0bupper_bound\x18\x03 \x01(\x01R\nupperBound\x12.\n\x10with_replacement\x18\x04 \x01(\x08H\x00R\x0fwithReplacement\x88\x01\x01\x12\x17\n\x04seed\x18\x05 \x01(\x03H\x01R\x04seed\x88\x01\x01\x12/\n\x13\x64\x65terministic_order\x18\x06 \x01(\x08R\x12\x64\x65terministicOrderB\x13\n\x11_with_replacementB\x07\n\x05_seed"\x91\x01\n\x05Range\x12\x19\n\x05start\x18\x01 \x01(\x03H\x00R\x05start\x88\x01\x01\x12\x10\n\x03\x65nd\x18\x02 \x01(\x03R\x03\x65nd\x12\x12\n\x04step\x18\x03 \x01(\x03R\x04step\x12*\n\x0enum_partitions\x18\x04 \x01(\x05H\x01R\rnumPartitions\x88\x01\x01\x42\x08\n\x06_startB\x11\n\x0f_num_partitions"r\n\rSubqueryAlias\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12\x14\n\x05\x61lias\x18\x02 \x01(\tR\x05\x61lias\x12\x1c\n\tqualifier\x18\x03 \x03(\tR\tqualifier"\x8e\x01\n\x0bRepartition\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12%\n\x0enum_partitions\x18\x02 \x01(\x05R\rnumPartitions\x12\x1d\n\x07shuffle\x18\x03 \x01(\x08H\x00R\x07shuffle\x88\x01\x01\x42\n\n\x08_shuffle"\x8e\x01\n\nShowString\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12\x19\n\x08num_rows\x18\x02 \x01(\x05R\x07numRows\x12\x1a\n\x08truncate\x18\x03 \x01(\x05R\x08truncate\x12\x1a\n\x08vertical\x18\x04 \x01(\x08R\x08vertical"r\n\nHtmlString\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12\x19\n\x08num_rows\x18\x02 \x01(\x05R\x07numRows\x12\x1a\n\x08truncate\x18\x03 \x01(\x05R\x08truncate"\\\n\x0bStatSummary\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12\x1e\n\nstatistics\x18\x02 \x03(\tR\nstatistics"Q\n\x0cStatDescribe\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12\x12\n\x04\x63ols\x18\x02 \x03(\tR\x04\x63ols"e\n\x0cStatCrosstab\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12\x12\n\x04\x63ol1\x18\x02 \x01(\tR\x04\x63ol1\x12\x12\n\x04\x63ol2\x18\x03 \x01(\tR\x04\x63ol2"`\n\x07StatCov\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12\x12\n\x04\x63ol1\x18\x02 \x01(\tR\x04\x63ol1\x12\x12\n\x04\x63ol2\x18\x03 \x01(\tR\x04\x63ol2"\x89\x01\n\x08StatCorr\x12-\n\x05input\x18\x01 
\x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12\x12\n\x04\x63ol1\x18\x02 \x01(\tR\x04\x63ol1\x12\x12\n\x04\x63ol2\x18\x03 \x01(\tR\x04\x63ol2\x12\x1b\n\x06method\x18\x04 \x01(\tH\x00R\x06method\x88\x01\x01\x42\t\n\x07_method"\xa4\x01\n\x12StatApproxQuantile\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12\x12\n\x04\x63ols\x18\x02 \x03(\tR\x04\x63ols\x12$\n\rprobabilities\x18\x03 \x03(\x01R\rprobabilities\x12%\n\x0erelative_error\x18\x04 \x01(\x01R\rrelativeError"}\n\rStatFreqItems\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12\x12\n\x04\x63ols\x18\x02 \x03(\tR\x04\x63ols\x12\x1d\n\x07support\x18\x03 \x01(\x01H\x00R\x07support\x88\x01\x01\x42\n\n\x08_support"\xb5\x02\n\x0cStatSampleBy\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12+\n\x03\x63ol\x18\x02 \x01(\x0b\x32\x19.spark.connect.ExpressionR\x03\x63ol\x12\x42\n\tfractions\x18\x03 \x03(\x0b\x32$.spark.connect.StatSampleBy.FractionR\tfractions\x12\x17\n\x04seed\x18\x05 \x01(\x03H\x00R\x04seed\x88\x01\x01\x1a\x63\n\x08\x46raction\x12;\n\x07stratum\x18\x01 \x01(\x0b\x32!.spark.connect.Expression.LiteralR\x07stratum\x12\x1a\n\x08\x66raction\x18\x02 \x01(\x01R\x08\x66ractionB\x07\n\x05_seed"\x86\x01\n\x06NAFill\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12\x12\n\x04\x63ols\x18\x02 \x03(\tR\x04\x63ols\x12\x39\n\x06values\x18\x03 \x03(\x0b\x32!.spark.connect.Expression.LiteralR\x06values"\x86\x01\n\x06NADrop\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12\x12\n\x04\x63ols\x18\x02 \x03(\tR\x04\x63ols\x12\'\n\rmin_non_nulls\x18\x03 \x01(\x05H\x00R\x0bminNonNulls\x88\x01\x01\x42\x10\n\x0e_min_non_nulls"\xa8\x02\n\tNAReplace\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12\x12\n\x04\x63ols\x18\x02 \x03(\tR\x04\x63ols\x12H\n\x0creplacements\x18\x03 \x03(\x0b\x32$.spark.connect.NAReplace.ReplacementR\x0creplacements\x1a\x8d\x01\n\x0bReplacement\x12>\n\told_value\x18\x01 \x01(\x0b\x32!.spark.connect.Expression.LiteralR\x08oldValue\x12>\n\tnew_value\x18\x02 \x01(\x0b\x32!.spark.connect.Expression.LiteralR\x08newValue"X\n\x04ToDF\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12!\n\x0c\x63olumn_names\x18\x02 \x03(\tR\x0b\x63olumnNames"\xfe\x02\n\x12WithColumnsRenamed\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12i\n\x12rename_columns_map\x18\x02 \x03(\x0b\x32\x37.spark.connect.WithColumnsRenamed.RenameColumnsMapEntryB\x02\x18\x01R\x10renameColumnsMap\x12\x42\n\x07renames\x18\x03 \x03(\x0b\x32(.spark.connect.WithColumnsRenamed.RenameR\x07renames\x1a\x43\n\x15RenameColumnsMapEntry\x12\x10\n\x03key\x18\x01 \x01(\tR\x03key\x12\x14\n\x05value\x18\x02 \x01(\tR\x05value:\x02\x38\x01\x1a\x45\n\x06Rename\x12\x19\n\x08\x63ol_name\x18\x01 \x01(\tR\x07\x63olName\x12 \n\x0cnew_col_name\x18\x02 \x01(\tR\nnewColName"w\n\x0bWithColumns\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12\x39\n\x07\x61liases\x18\x02 \x03(\x0b\x32\x1f.spark.connect.Expression.AliasR\x07\x61liases"\x86\x01\n\rWithWatermark\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12\x1d\n\nevent_time\x18\x02 \x01(\tR\teventTime\x12\'\n\x0f\x64\x65lay_threshold\x18\x03 \x01(\tR\x0e\x64\x65layThreshold"\x84\x01\n\x04Hint\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12\x12\n\x04name\x18\x02 \x01(\tR\x04name\x12\x39\n\nparameters\x18\x03 
\x03(\x0b\x32\x19.spark.connect.ExpressionR\nparameters"\xc7\x02\n\x07Unpivot\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12+\n\x03ids\x18\x02 \x03(\x0b\x32\x19.spark.connect.ExpressionR\x03ids\x12:\n\x06values\x18\x03 \x01(\x0b\x32\x1d.spark.connect.Unpivot.ValuesH\x00R\x06values\x88\x01\x01\x12\x30\n\x14variable_column_name\x18\x04 \x01(\tR\x12variableColumnName\x12*\n\x11value_column_name\x18\x05 \x01(\tR\x0fvalueColumnName\x1a;\n\x06Values\x12\x31\n\x06values\x18\x01 \x03(\x0b\x32\x19.spark.connect.ExpressionR\x06valuesB\t\n\x07_values"z\n\tTranspose\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12>\n\rindex_columns\x18\x02 \x03(\x0b\x32\x19.spark.connect.ExpressionR\x0cindexColumns"}\n\x1dUnresolvedTableValuedFunction\x12#\n\rfunction_name\x18\x01 \x01(\tR\x0c\x66unctionName\x12\x37\n\targuments\x18\x02 \x03(\x0b\x32\x19.spark.connect.ExpressionR\targuments"j\n\x08ToSchema\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12/\n\x06schema\x18\x02 \x01(\x0b\x32\x17.spark.connect.DataTypeR\x06schema"\xcb\x01\n\x17RepartitionByExpression\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12\x42\n\x0fpartition_exprs\x18\x02 \x03(\x0b\x32\x19.spark.connect.ExpressionR\x0epartitionExprs\x12*\n\x0enum_partitions\x18\x03 \x01(\x05H\x00R\rnumPartitions\x88\x01\x01\x42\x11\n\x0f_num_partitions"\xe8\x01\n\rMapPartitions\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12\x42\n\x04\x66unc\x18\x02 \x01(\x0b\x32..spark.connect.CommonInlineUserDefinedFunctionR\x04\x66unc\x12"\n\nis_barrier\x18\x03 \x01(\x08H\x00R\tisBarrier\x88\x01\x01\x12"\n\nprofile_id\x18\x04 \x01(\x05H\x01R\tprofileId\x88\x01\x01\x42\r\n\x0b_is_barrierB\r\n\x0b_profile_id"\xd2\x06\n\x08GroupMap\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12L\n\x14grouping_expressions\x18\x02 \x03(\x0b\x32\x19.spark.connect.ExpressionR\x13groupingExpressions\x12\x42\n\x04\x66unc\x18\x03 \x01(\x0b\x32..spark.connect.CommonInlineUserDefinedFunctionR\x04\x66unc\x12J\n\x13sorting_expressions\x18\x04 \x03(\x0b\x32\x19.spark.connect.ExpressionR\x12sortingExpressions\x12<\n\rinitial_input\x18\x05 \x01(\x0b\x32\x17.spark.connect.RelationR\x0cinitialInput\x12[\n\x1cinitial_grouping_expressions\x18\x06 \x03(\x0b\x32\x19.spark.connect.ExpressionR\x1ainitialGroupingExpressions\x12;\n\x18is_map_groups_with_state\x18\x07 \x01(\x08H\x00R\x14isMapGroupsWithState\x88\x01\x01\x12$\n\x0boutput_mode\x18\x08 \x01(\tH\x01R\noutputMode\x88\x01\x01\x12&\n\x0ctimeout_conf\x18\t \x01(\tH\x02R\x0btimeoutConf\x88\x01\x01\x12?\n\x0cstate_schema\x18\n \x01(\x0b\x32\x17.spark.connect.DataTypeH\x03R\x0bstateSchema\x88\x01\x01\x12\x65\n\x19transform_with_state_info\x18\x0b \x01(\x0b\x32%.spark.connect.TransformWithStateInfoH\x04R\x16transformWithStateInfo\x88\x01\x01\x42\x1b\n\x19_is_map_groups_with_stateB\x0e\n\x0c_output_modeB\x0f\n\r_timeout_confB\x0f\n\r_state_schemaB\x1c\n\x1a_transform_with_state_info"\xdf\x01\n\x16TransformWithStateInfo\x12\x1b\n\ttime_mode\x18\x01 \x01(\tR\x08timeMode\x12\x38\n\x16\x65vent_time_column_name\x18\x02 \x01(\tH\x00R\x13\x65ventTimeColumnName\x88\x01\x01\x12\x41\n\routput_schema\x18\x03 \x01(\x0b\x32\x17.spark.connect.DataTypeH\x01R\x0coutputSchema\x88\x01\x01\x42\x19\n\x17_event_time_column_nameB\x10\n\x0e_output_schema"\x8e\x04\n\nCoGroupMap\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12W\n\x1ainput_grouping_expressions\x18\x02 
\x03(\x0b\x32\x19.spark.connect.ExpressionR\x18inputGroupingExpressions\x12-\n\x05other\x18\x03 \x01(\x0b\x32\x17.spark.connect.RelationR\x05other\x12W\n\x1aother_grouping_expressions\x18\x04 \x03(\x0b\x32\x19.spark.connect.ExpressionR\x18otherGroupingExpressions\x12\x42\n\x04\x66unc\x18\x05 \x01(\x0b\x32..spark.connect.CommonInlineUserDefinedFunctionR\x04\x66unc\x12U\n\x19input_sorting_expressions\x18\x06 \x03(\x0b\x32\x19.spark.connect.ExpressionR\x17inputSortingExpressions\x12U\n\x19other_sorting_expressions\x18\x07 \x03(\x0b\x32\x19.spark.connect.ExpressionR\x17otherSortingExpressions"\xe5\x02\n\x16\x41pplyInPandasWithState\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12L\n\x14grouping_expressions\x18\x02 \x03(\x0b\x32\x19.spark.connect.ExpressionR\x13groupingExpressions\x12\x42\n\x04\x66unc\x18\x03 \x01(\x0b\x32..spark.connect.CommonInlineUserDefinedFunctionR\x04\x66unc\x12#\n\routput_schema\x18\x04 \x01(\tR\x0coutputSchema\x12!\n\x0cstate_schema\x18\x05 \x01(\tR\x0bstateSchema\x12\x1f\n\x0boutput_mode\x18\x06 \x01(\tR\noutputMode\x12!\n\x0ctimeout_conf\x18\x07 \x01(\tR\x0btimeoutConf"\xf4\x01\n$CommonInlineUserDefinedTableFunction\x12#\n\rfunction_name\x18\x01 \x01(\tR\x0c\x66unctionName\x12$\n\rdeterministic\x18\x02 \x01(\x08R\rdeterministic\x12\x37\n\targuments\x18\x03 \x03(\x0b\x32\x19.spark.connect.ExpressionR\targuments\x12<\n\x0bpython_udtf\x18\x04 \x01(\x0b\x32\x19.spark.connect.PythonUDTFH\x00R\npythonUdtfB\n\n\x08\x66unction"\xb1\x01\n\nPythonUDTF\x12=\n\x0breturn_type\x18\x01 \x01(\x0b\x32\x17.spark.connect.DataTypeH\x00R\nreturnType\x88\x01\x01\x12\x1b\n\teval_type\x18\x02 \x01(\x05R\x08\x65valType\x12\x18\n\x07\x63ommand\x18\x03 \x01(\x0cR\x07\x63ommand\x12\x1d\n\npython_ver\x18\x04 \x01(\tR\tpythonVerB\x0e\n\x0c_return_type"\x97\x01\n!CommonInlineUserDefinedDataSource\x12\x12\n\x04name\x18\x01 \x01(\tR\x04name\x12O\n\x12python_data_source\x18\x02 \x01(\x0b\x32\x1f.spark.connect.PythonDataSourceH\x00R\x10pythonDataSourceB\r\n\x0b\x64\x61ta_source"K\n\x10PythonDataSource\x12\x18\n\x07\x63ommand\x18\x01 \x01(\x0cR\x07\x63ommand\x12\x1d\n\npython_ver\x18\x02 \x01(\tR\tpythonVer"\x88\x01\n\x0e\x43ollectMetrics\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12\x12\n\x04name\x18\x02 \x01(\tR\x04name\x12\x33\n\x07metrics\x18\x03 \x03(\x0b\x32\x19.spark.connect.ExpressionR\x07metrics"\x84\x03\n\x05Parse\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12\x38\n\x06\x66ormat\x18\x02 \x01(\x0e\x32 .spark.connect.Parse.ParseFormatR\x06\x66ormat\x12\x34\n\x06schema\x18\x03 \x01(\x0b\x32\x17.spark.connect.DataTypeH\x00R\x06schema\x88\x01\x01\x12;\n\x07options\x18\x04 \x03(\x0b\x32!.spark.connect.Parse.OptionsEntryR\x07options\x1a:\n\x0cOptionsEntry\x12\x10\n\x03key\x18\x01 \x01(\tR\x03key\x12\x14\n\x05value\x18\x02 \x01(\tR\x05value:\x02\x38\x01"X\n\x0bParseFormat\x12\x1c\n\x18PARSE_FORMAT_UNSPECIFIED\x10\x00\x12\x14\n\x10PARSE_FORMAT_CSV\x10\x01\x12\x15\n\x11PARSE_FORMAT_JSON\x10\x02\x42\t\n\x07_schema"\xdb\x03\n\x08\x41sOfJoin\x12+\n\x04left\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x04left\x12-\n\x05right\x18\x02 \x01(\x0b\x32\x17.spark.connect.RelationR\x05right\x12\x37\n\nleft_as_of\x18\x03 \x01(\x0b\x32\x19.spark.connect.ExpressionR\x08leftAsOf\x12\x39\n\x0bright_as_of\x18\x04 \x01(\x0b\x32\x19.spark.connect.ExpressionR\trightAsOf\x12\x36\n\tjoin_expr\x18\x05 \x01(\x0b\x32\x19.spark.connect.ExpressionR\x08joinExpr\x12#\n\rusing_columns\x18\x06 
\x03(\tR\x0cusingColumns\x12\x1b\n\tjoin_type\x18\x07 \x01(\tR\x08joinType\x12\x37\n\ttolerance\x18\x08 \x01(\x0b\x32\x19.spark.connect.ExpressionR\ttolerance\x12.\n\x13\x61llow_exact_matches\x18\t \x01(\x08R\x11\x61llowExactMatches\x12\x1c\n\tdirection\x18\n \x01(\tR\tdirection"\xe6\x01\n\x0bLateralJoin\x12+\n\x04left\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x04left\x12-\n\x05right\x18\x02 \x01(\x0b\x32\x17.spark.connect.RelationR\x05right\x12@\n\x0ejoin_condition\x18\x03 \x01(\x0b\x32\x19.spark.connect.ExpressionR\rjoinCondition\x12\x39\n\tjoin_type\x18\x04 \x01(\x0e\x32\x1c.spark.connect.Join.JoinTypeR\x08joinTypeB6\n\x1eorg.apache.spark.connect.protoP\x01Z\x12internal/generatedb\x06proto3' ) _globals = globals() @@ -81,169 +81,169 @@ _globals["_RELATION"]._serialized_start = 224 _globals["_RELATION"]._serialized_end = 3964 _globals["_MLRELATION"]._serialized_start = 3967 - _globals["_MLRELATION"]._serialized_end = 4343 - _globals["_MLRELATION_TRANSFORM"]._serialized_start = 4097 - _globals["_MLRELATION_TRANSFORM"]._serialized_end = 4332 - _globals["_FETCH"]._serialized_start = 4346 - _globals["_FETCH"]._serialized_end = 4677 - _globals["_FETCH_METHOD"]._serialized_start = 4462 - _globals["_FETCH_METHOD"]._serialized_end = 4677 - _globals["_FETCH_METHOD_ARGS"]._serialized_start = 4550 - _globals["_FETCH_METHOD_ARGS"]._serialized_end = 4677 - _globals["_UNKNOWN"]._serialized_start = 4679 - _globals["_UNKNOWN"]._serialized_end = 4688 - _globals["_RELATIONCOMMON"]._serialized_start = 4691 - _globals["_RELATIONCOMMON"]._serialized_end = 4833 - _globals["_SQL"]._serialized_start = 4836 - _globals["_SQL"]._serialized_end = 5314 - _globals["_SQL_ARGSENTRY"]._serialized_start = 5130 - _globals["_SQL_ARGSENTRY"]._serialized_end = 5220 - _globals["_SQL_NAMEDARGUMENTSENTRY"]._serialized_start = 5222 - _globals["_SQL_NAMEDARGUMENTSENTRY"]._serialized_end = 5314 - _globals["_WITHRELATIONS"]._serialized_start = 5316 - _globals["_WITHRELATIONS"]._serialized_end = 5433 - _globals["_READ"]._serialized_start = 5436 - _globals["_READ"]._serialized_end = 6099 - _globals["_READ_NAMEDTABLE"]._serialized_start = 5614 - _globals["_READ_NAMEDTABLE"]._serialized_end = 5806 - _globals["_READ_NAMEDTABLE_OPTIONSENTRY"]._serialized_start = 5748 - _globals["_READ_NAMEDTABLE_OPTIONSENTRY"]._serialized_end = 5806 - _globals["_READ_DATASOURCE"]._serialized_start = 5809 - _globals["_READ_DATASOURCE"]._serialized_end = 6086 - _globals["_READ_DATASOURCE_OPTIONSENTRY"]._serialized_start = 5748 - _globals["_READ_DATASOURCE_OPTIONSENTRY"]._serialized_end = 5806 - _globals["_PROJECT"]._serialized_start = 6101 - _globals["_PROJECT"]._serialized_end = 6218 - _globals["_FILTER"]._serialized_start = 6220 - _globals["_FILTER"]._serialized_end = 6332 - _globals["_JOIN"]._serialized_start = 6335 - _globals["_JOIN"]._serialized_end = 6996 - _globals["_JOIN_JOINDATATYPE"]._serialized_start = 6674 - _globals["_JOIN_JOINDATATYPE"]._serialized_end = 6766 - _globals["_JOIN_JOINTYPE"]._serialized_start = 6769 - _globals["_JOIN_JOINTYPE"]._serialized_end = 6977 - _globals["_SETOPERATION"]._serialized_start = 6999 - _globals["_SETOPERATION"]._serialized_end = 7478 - _globals["_SETOPERATION_SETOPTYPE"]._serialized_start = 7315 - _globals["_SETOPERATION_SETOPTYPE"]._serialized_end = 7429 - _globals["_LIMIT"]._serialized_start = 7480 - _globals["_LIMIT"]._serialized_end = 7556 - _globals["_OFFSET"]._serialized_start = 7558 - _globals["_OFFSET"]._serialized_end = 7637 - _globals["_TAIL"]._serialized_start = 7639 - 
_globals["_TAIL"]._serialized_end = 7714 - _globals["_AGGREGATE"]._serialized_start = 7717 - _globals["_AGGREGATE"]._serialized_end = 8483 - _globals["_AGGREGATE_PIVOT"]._serialized_start = 8132 - _globals["_AGGREGATE_PIVOT"]._serialized_end = 8243 - _globals["_AGGREGATE_GROUPINGSETS"]._serialized_start = 8245 - _globals["_AGGREGATE_GROUPINGSETS"]._serialized_end = 8321 - _globals["_AGGREGATE_GROUPTYPE"]._serialized_start = 8324 - _globals["_AGGREGATE_GROUPTYPE"]._serialized_end = 8483 - _globals["_SORT"]._serialized_start = 8486 - _globals["_SORT"]._serialized_end = 8646 - _globals["_DROP"]._serialized_start = 8649 - _globals["_DROP"]._serialized_end = 8790 - _globals["_DEDUPLICATE"]._serialized_start = 8793 - _globals["_DEDUPLICATE"]._serialized_end = 9033 - _globals["_LOCALRELATION"]._serialized_start = 9035 - _globals["_LOCALRELATION"]._serialized_end = 9124 - _globals["_CACHEDLOCALRELATION"]._serialized_start = 9126 - _globals["_CACHEDLOCALRELATION"]._serialized_end = 9198 - _globals["_CACHEDREMOTERELATION"]._serialized_start = 9200 - _globals["_CACHEDREMOTERELATION"]._serialized_end = 9255 - _globals["_SAMPLE"]._serialized_start = 9258 - _globals["_SAMPLE"]._serialized_end = 9531 - _globals["_RANGE"]._serialized_start = 9534 - _globals["_RANGE"]._serialized_end = 9679 - _globals["_SUBQUERYALIAS"]._serialized_start = 9681 - _globals["_SUBQUERYALIAS"]._serialized_end = 9795 - _globals["_REPARTITION"]._serialized_start = 9798 - _globals["_REPARTITION"]._serialized_end = 9940 - _globals["_SHOWSTRING"]._serialized_start = 9943 - _globals["_SHOWSTRING"]._serialized_end = 10085 - _globals["_HTMLSTRING"]._serialized_start = 10087 - _globals["_HTMLSTRING"]._serialized_end = 10201 - _globals["_STATSUMMARY"]._serialized_start = 10203 - _globals["_STATSUMMARY"]._serialized_end = 10295 - _globals["_STATDESCRIBE"]._serialized_start = 10297 - _globals["_STATDESCRIBE"]._serialized_end = 10378 - _globals["_STATCROSSTAB"]._serialized_start = 10380 - _globals["_STATCROSSTAB"]._serialized_end = 10481 - _globals["_STATCOV"]._serialized_start = 10483 - _globals["_STATCOV"]._serialized_end = 10579 - _globals["_STATCORR"]._serialized_start = 10582 - _globals["_STATCORR"]._serialized_end = 10719 - _globals["_STATAPPROXQUANTILE"]._serialized_start = 10722 - _globals["_STATAPPROXQUANTILE"]._serialized_end = 10886 - _globals["_STATFREQITEMS"]._serialized_start = 10888 - _globals["_STATFREQITEMS"]._serialized_end = 11013 - _globals["_STATSAMPLEBY"]._serialized_start = 11016 - _globals["_STATSAMPLEBY"]._serialized_end = 11325 - _globals["_STATSAMPLEBY_FRACTION"]._serialized_start = 11217 - _globals["_STATSAMPLEBY_FRACTION"]._serialized_end = 11316 - _globals["_NAFILL"]._serialized_start = 11328 - _globals["_NAFILL"]._serialized_end = 11462 - _globals["_NADROP"]._serialized_start = 11465 - _globals["_NADROP"]._serialized_end = 11599 - _globals["_NAREPLACE"]._serialized_start = 11602 - _globals["_NAREPLACE"]._serialized_end = 11898 - _globals["_NAREPLACE_REPLACEMENT"]._serialized_start = 11757 - _globals["_NAREPLACE_REPLACEMENT"]._serialized_end = 11898 - _globals["_TODF"]._serialized_start = 11900 - _globals["_TODF"]._serialized_end = 11988 - _globals["_WITHCOLUMNSRENAMED"]._serialized_start = 11991 - _globals["_WITHCOLUMNSRENAMED"]._serialized_end = 12373 - _globals["_WITHCOLUMNSRENAMED_RENAMECOLUMNSMAPENTRY"]._serialized_start = 12235 - _globals["_WITHCOLUMNSRENAMED_RENAMECOLUMNSMAPENTRY"]._serialized_end = 12302 - _globals["_WITHCOLUMNSRENAMED_RENAME"]._serialized_start = 12304 - 
_globals["_WITHCOLUMNSRENAMED_RENAME"]._serialized_end = 12373 - _globals["_WITHCOLUMNS"]._serialized_start = 12375 - _globals["_WITHCOLUMNS"]._serialized_end = 12494 - _globals["_WITHWATERMARK"]._serialized_start = 12497 - _globals["_WITHWATERMARK"]._serialized_end = 12631 - _globals["_HINT"]._serialized_start = 12634 - _globals["_HINT"]._serialized_end = 12766 - _globals["_UNPIVOT"]._serialized_start = 12769 - _globals["_UNPIVOT"]._serialized_end = 13096 - _globals["_UNPIVOT_VALUES"]._serialized_start = 13026 - _globals["_UNPIVOT_VALUES"]._serialized_end = 13085 - _globals["_TRANSPOSE"]._serialized_start = 13098 - _globals["_TRANSPOSE"]._serialized_end = 13220 - _globals["_UNRESOLVEDTABLEVALUEDFUNCTION"]._serialized_start = 13222 - _globals["_UNRESOLVEDTABLEVALUEDFUNCTION"]._serialized_end = 13347 - _globals["_TOSCHEMA"]._serialized_start = 13349 - _globals["_TOSCHEMA"]._serialized_end = 13455 - _globals["_REPARTITIONBYEXPRESSION"]._serialized_start = 13458 - _globals["_REPARTITIONBYEXPRESSION"]._serialized_end = 13661 - _globals["_MAPPARTITIONS"]._serialized_start = 13664 - _globals["_MAPPARTITIONS"]._serialized_end = 13896 - _globals["_GROUPMAP"]._serialized_start = 13899 - _globals["_GROUPMAP"]._serialized_end = 14749 - _globals["_TRANSFORMWITHSTATEINFO"]._serialized_start = 14752 - _globals["_TRANSFORMWITHSTATEINFO"]._serialized_end = 14975 - _globals["_COGROUPMAP"]._serialized_start = 14978 - _globals["_COGROUPMAP"]._serialized_end = 15504 - _globals["_APPLYINPANDASWITHSTATE"]._serialized_start = 15507 - _globals["_APPLYINPANDASWITHSTATE"]._serialized_end = 15864 - _globals["_COMMONINLINEUSERDEFINEDTABLEFUNCTION"]._serialized_start = 15867 - _globals["_COMMONINLINEUSERDEFINEDTABLEFUNCTION"]._serialized_end = 16111 - _globals["_PYTHONUDTF"]._serialized_start = 16114 - _globals["_PYTHONUDTF"]._serialized_end = 16291 - _globals["_COMMONINLINEUSERDEFINEDDATASOURCE"]._serialized_start = 16294 - _globals["_COMMONINLINEUSERDEFINEDDATASOURCE"]._serialized_end = 16445 - _globals["_PYTHONDATASOURCE"]._serialized_start = 16447 - _globals["_PYTHONDATASOURCE"]._serialized_end = 16522 - _globals["_COLLECTMETRICS"]._serialized_start = 16525 - _globals["_COLLECTMETRICS"]._serialized_end = 16661 - _globals["_PARSE"]._serialized_start = 16664 - _globals["_PARSE"]._serialized_end = 17052 - _globals["_PARSE_OPTIONSENTRY"]._serialized_start = 5748 - _globals["_PARSE_OPTIONSENTRY"]._serialized_end = 5806 - _globals["_PARSE_PARSEFORMAT"]._serialized_start = 16953 - _globals["_PARSE_PARSEFORMAT"]._serialized_end = 17041 - _globals["_ASOFJOIN"]._serialized_start = 17055 - _globals["_ASOFJOIN"]._serialized_end = 17530 - _globals["_LATERALJOIN"]._serialized_start = 17533 - _globals["_LATERALJOIN"]._serialized_end = 17763 + _globals["_MLRELATION"]._serialized_end = 4451 + _globals["_MLRELATION_TRANSFORM"]._serialized_start = 4179 + _globals["_MLRELATION_TRANSFORM"]._serialized_end = 4414 + _globals["_FETCH"]._serialized_start = 4454 + _globals["_FETCH"]._serialized_end = 4785 + _globals["_FETCH_METHOD"]._serialized_start = 4570 + _globals["_FETCH_METHOD"]._serialized_end = 4785 + _globals["_FETCH_METHOD_ARGS"]._serialized_start = 4658 + _globals["_FETCH_METHOD_ARGS"]._serialized_end = 4785 + _globals["_UNKNOWN"]._serialized_start = 4787 + _globals["_UNKNOWN"]._serialized_end = 4796 + _globals["_RELATIONCOMMON"]._serialized_start = 4799 + _globals["_RELATIONCOMMON"]._serialized_end = 4941 + _globals["_SQL"]._serialized_start = 4944 + _globals["_SQL"]._serialized_end = 5422 + 
_globals["_SQL_ARGSENTRY"]._serialized_start = 5238 + _globals["_SQL_ARGSENTRY"]._serialized_end = 5328 + _globals["_SQL_NAMEDARGUMENTSENTRY"]._serialized_start = 5330 + _globals["_SQL_NAMEDARGUMENTSENTRY"]._serialized_end = 5422 + _globals["_WITHRELATIONS"]._serialized_start = 5424 + _globals["_WITHRELATIONS"]._serialized_end = 5541 + _globals["_READ"]._serialized_start = 5544 + _globals["_READ"]._serialized_end = 6207 + _globals["_READ_NAMEDTABLE"]._serialized_start = 5722 + _globals["_READ_NAMEDTABLE"]._serialized_end = 5914 + _globals["_READ_NAMEDTABLE_OPTIONSENTRY"]._serialized_start = 5856 + _globals["_READ_NAMEDTABLE_OPTIONSENTRY"]._serialized_end = 5914 + _globals["_READ_DATASOURCE"]._serialized_start = 5917 + _globals["_READ_DATASOURCE"]._serialized_end = 6194 + _globals["_READ_DATASOURCE_OPTIONSENTRY"]._serialized_start = 5856 + _globals["_READ_DATASOURCE_OPTIONSENTRY"]._serialized_end = 5914 + _globals["_PROJECT"]._serialized_start = 6209 + _globals["_PROJECT"]._serialized_end = 6326 + _globals["_FILTER"]._serialized_start = 6328 + _globals["_FILTER"]._serialized_end = 6440 + _globals["_JOIN"]._serialized_start = 6443 + _globals["_JOIN"]._serialized_end = 7104 + _globals["_JOIN_JOINDATATYPE"]._serialized_start = 6782 + _globals["_JOIN_JOINDATATYPE"]._serialized_end = 6874 + _globals["_JOIN_JOINTYPE"]._serialized_start = 6877 + _globals["_JOIN_JOINTYPE"]._serialized_end = 7085 + _globals["_SETOPERATION"]._serialized_start = 7107 + _globals["_SETOPERATION"]._serialized_end = 7586 + _globals["_SETOPERATION_SETOPTYPE"]._serialized_start = 7423 + _globals["_SETOPERATION_SETOPTYPE"]._serialized_end = 7537 + _globals["_LIMIT"]._serialized_start = 7588 + _globals["_LIMIT"]._serialized_end = 7664 + _globals["_OFFSET"]._serialized_start = 7666 + _globals["_OFFSET"]._serialized_end = 7745 + _globals["_TAIL"]._serialized_start = 7747 + _globals["_TAIL"]._serialized_end = 7822 + _globals["_AGGREGATE"]._serialized_start = 7825 + _globals["_AGGREGATE"]._serialized_end = 8591 + _globals["_AGGREGATE_PIVOT"]._serialized_start = 8240 + _globals["_AGGREGATE_PIVOT"]._serialized_end = 8351 + _globals["_AGGREGATE_GROUPINGSETS"]._serialized_start = 8353 + _globals["_AGGREGATE_GROUPINGSETS"]._serialized_end = 8429 + _globals["_AGGREGATE_GROUPTYPE"]._serialized_start = 8432 + _globals["_AGGREGATE_GROUPTYPE"]._serialized_end = 8591 + _globals["_SORT"]._serialized_start = 8594 + _globals["_SORT"]._serialized_end = 8754 + _globals["_DROP"]._serialized_start = 8757 + _globals["_DROP"]._serialized_end = 8898 + _globals["_DEDUPLICATE"]._serialized_start = 8901 + _globals["_DEDUPLICATE"]._serialized_end = 9141 + _globals["_LOCALRELATION"]._serialized_start = 9143 + _globals["_LOCALRELATION"]._serialized_end = 9232 + _globals["_CACHEDLOCALRELATION"]._serialized_start = 9234 + _globals["_CACHEDLOCALRELATION"]._serialized_end = 9306 + _globals["_CACHEDREMOTERELATION"]._serialized_start = 9308 + _globals["_CACHEDREMOTERELATION"]._serialized_end = 9363 + _globals["_SAMPLE"]._serialized_start = 9366 + _globals["_SAMPLE"]._serialized_end = 9639 + _globals["_RANGE"]._serialized_start = 9642 + _globals["_RANGE"]._serialized_end = 9787 + _globals["_SUBQUERYALIAS"]._serialized_start = 9789 + _globals["_SUBQUERYALIAS"]._serialized_end = 9903 + _globals["_REPARTITION"]._serialized_start = 9906 + _globals["_REPARTITION"]._serialized_end = 10048 + _globals["_SHOWSTRING"]._serialized_start = 10051 + _globals["_SHOWSTRING"]._serialized_end = 10193 + _globals["_HTMLSTRING"]._serialized_start = 10195 + 
_globals["_HTMLSTRING"]._serialized_end = 10309 + _globals["_STATSUMMARY"]._serialized_start = 10311 + _globals["_STATSUMMARY"]._serialized_end = 10403 + _globals["_STATDESCRIBE"]._serialized_start = 10405 + _globals["_STATDESCRIBE"]._serialized_end = 10486 + _globals["_STATCROSSTAB"]._serialized_start = 10488 + _globals["_STATCROSSTAB"]._serialized_end = 10589 + _globals["_STATCOV"]._serialized_start = 10591 + _globals["_STATCOV"]._serialized_end = 10687 + _globals["_STATCORR"]._serialized_start = 10690 + _globals["_STATCORR"]._serialized_end = 10827 + _globals["_STATAPPROXQUANTILE"]._serialized_start = 10830 + _globals["_STATAPPROXQUANTILE"]._serialized_end = 10994 + _globals["_STATFREQITEMS"]._serialized_start = 10996 + _globals["_STATFREQITEMS"]._serialized_end = 11121 + _globals["_STATSAMPLEBY"]._serialized_start = 11124 + _globals["_STATSAMPLEBY"]._serialized_end = 11433 + _globals["_STATSAMPLEBY_FRACTION"]._serialized_start = 11325 + _globals["_STATSAMPLEBY_FRACTION"]._serialized_end = 11424 + _globals["_NAFILL"]._serialized_start = 11436 + _globals["_NAFILL"]._serialized_end = 11570 + _globals["_NADROP"]._serialized_start = 11573 + _globals["_NADROP"]._serialized_end = 11707 + _globals["_NAREPLACE"]._serialized_start = 11710 + _globals["_NAREPLACE"]._serialized_end = 12006 + _globals["_NAREPLACE_REPLACEMENT"]._serialized_start = 11865 + _globals["_NAREPLACE_REPLACEMENT"]._serialized_end = 12006 + _globals["_TODF"]._serialized_start = 12008 + _globals["_TODF"]._serialized_end = 12096 + _globals["_WITHCOLUMNSRENAMED"]._serialized_start = 12099 + _globals["_WITHCOLUMNSRENAMED"]._serialized_end = 12481 + _globals["_WITHCOLUMNSRENAMED_RENAMECOLUMNSMAPENTRY"]._serialized_start = 12343 + _globals["_WITHCOLUMNSRENAMED_RENAMECOLUMNSMAPENTRY"]._serialized_end = 12410 + _globals["_WITHCOLUMNSRENAMED_RENAME"]._serialized_start = 12412 + _globals["_WITHCOLUMNSRENAMED_RENAME"]._serialized_end = 12481 + _globals["_WITHCOLUMNS"]._serialized_start = 12483 + _globals["_WITHCOLUMNS"]._serialized_end = 12602 + _globals["_WITHWATERMARK"]._serialized_start = 12605 + _globals["_WITHWATERMARK"]._serialized_end = 12739 + _globals["_HINT"]._serialized_start = 12742 + _globals["_HINT"]._serialized_end = 12874 + _globals["_UNPIVOT"]._serialized_start = 12877 + _globals["_UNPIVOT"]._serialized_end = 13204 + _globals["_UNPIVOT_VALUES"]._serialized_start = 13134 + _globals["_UNPIVOT_VALUES"]._serialized_end = 13193 + _globals["_TRANSPOSE"]._serialized_start = 13206 + _globals["_TRANSPOSE"]._serialized_end = 13328 + _globals["_UNRESOLVEDTABLEVALUEDFUNCTION"]._serialized_start = 13330 + _globals["_UNRESOLVEDTABLEVALUEDFUNCTION"]._serialized_end = 13455 + _globals["_TOSCHEMA"]._serialized_start = 13457 + _globals["_TOSCHEMA"]._serialized_end = 13563 + _globals["_REPARTITIONBYEXPRESSION"]._serialized_start = 13566 + _globals["_REPARTITIONBYEXPRESSION"]._serialized_end = 13769 + _globals["_MAPPARTITIONS"]._serialized_start = 13772 + _globals["_MAPPARTITIONS"]._serialized_end = 14004 + _globals["_GROUPMAP"]._serialized_start = 14007 + _globals["_GROUPMAP"]._serialized_end = 14857 + _globals["_TRANSFORMWITHSTATEINFO"]._serialized_start = 14860 + _globals["_TRANSFORMWITHSTATEINFO"]._serialized_end = 15083 + _globals["_COGROUPMAP"]._serialized_start = 15086 + _globals["_COGROUPMAP"]._serialized_end = 15612 + _globals["_APPLYINPANDASWITHSTATE"]._serialized_start = 15615 + _globals["_APPLYINPANDASWITHSTATE"]._serialized_end = 15972 + _globals["_COMMONINLINEUSERDEFINEDTABLEFUNCTION"]._serialized_start = 15975 + 
_globals["_COMMONINLINEUSERDEFINEDTABLEFUNCTION"]._serialized_end = 16219 + _globals["_PYTHONUDTF"]._serialized_start = 16222 + _globals["_PYTHONUDTF"]._serialized_end = 16399 + _globals["_COMMONINLINEUSERDEFINEDDATASOURCE"]._serialized_start = 16402 + _globals["_COMMONINLINEUSERDEFINEDDATASOURCE"]._serialized_end = 16553 + _globals["_PYTHONDATASOURCE"]._serialized_start = 16555 + _globals["_PYTHONDATASOURCE"]._serialized_end = 16630 + _globals["_COLLECTMETRICS"]._serialized_start = 16633 + _globals["_COLLECTMETRICS"]._serialized_end = 16769 + _globals["_PARSE"]._serialized_start = 16772 + _globals["_PARSE"]._serialized_end = 17160 + _globals["_PARSE_OPTIONSENTRY"]._serialized_start = 5856 + _globals["_PARSE_OPTIONSENTRY"]._serialized_end = 5914 + _globals["_PARSE_PARSEFORMAT"]._serialized_start = 17061 + _globals["_PARSE_PARSEFORMAT"]._serialized_end = 17149 + _globals["_ASOFJOIN"]._serialized_start = 17163 + _globals["_ASOFJOIN"]._serialized_end = 17638 + _globals["_LATERALJOIN"]._serialized_start = 17641 + _globals["_LATERALJOIN"]._serialized_end = 17871 # @@protoc_insertion_point(module_scope) diff --git a/python/pyspark/sql/connect/proto/relations_pb2.pyi b/python/pyspark/sql/connect/proto/relations_pb2.pyi index beeeb712da762..e1eb7945c19f0 100644 --- a/python/pyspark/sql/connect/proto/relations_pb2.pyi +++ b/python/pyspark/sql/connect/proto/relations_pb2.pyi @@ -707,28 +707,57 @@ class MlRelation(google.protobuf.message.Message): TRANSFORM_FIELD_NUMBER: builtins.int FETCH_FIELD_NUMBER: builtins.int + MODEL_SUMMARY_DATASET_FIELD_NUMBER: builtins.int @property def transform(self) -> global___MlRelation.Transform: ... @property def fetch(self) -> global___Fetch: ... + @property + def model_summary_dataset(self) -> global___Relation: + """(Optional) the dataset for restoring the model summary""" def __init__( self, *, transform: global___MlRelation.Transform | None = ..., fetch: global___Fetch | None = ..., + model_summary_dataset: global___Relation | None = ..., ) -> None: ... def HasField( self, field_name: typing_extensions.Literal[ - "fetch", b"fetch", "ml_type", b"ml_type", "transform", b"transform" + "_model_summary_dataset", + b"_model_summary_dataset", + "fetch", + b"fetch", + "ml_type", + b"ml_type", + "model_summary_dataset", + b"model_summary_dataset", + "transform", + b"transform", ], ) -> builtins.bool: ... def ClearField( self, field_name: typing_extensions.Literal[ - "fetch", b"fetch", "ml_type", b"ml_type", "transform", b"transform" + "_model_summary_dataset", + b"_model_summary_dataset", + "fetch", + b"fetch", + "ml_type", + b"ml_type", + "model_summary_dataset", + b"model_summary_dataset", + "transform", + b"transform", ], ) -> None: ... + @typing.overload + def WhichOneof( + self, + oneof_group: typing_extensions.Literal["_model_summary_dataset", b"_model_summary_dataset"], + ) -> typing_extensions.Literal["model_summary_dataset"] | None: ... + @typing.overload def WhichOneof( self, oneof_group: typing_extensions.Literal["ml_type", b"ml_type"] ) -> typing_extensions.Literal["transform", "fetch"] | None: ... 
diff --git a/sql/connect/common/src/main/protobuf/spark/connect/ml.proto b/sql/connect/common/src/main/protobuf/spark/connect/ml.proto index b66c0a186df39..3497284af4ab8 100644 --- a/sql/connect/common/src/main/protobuf/spark/connect/ml.proto +++ b/sql/connect/common/src/main/protobuf/spark/connect/ml.proto @@ -38,6 +38,7 @@ message MlCommand { Evaluate evaluate = 6; CleanCache clean_cache = 7; GetCacheInfo get_cache_info = 8; + CreateSummary create_summary = 9; } // Command for estimator.fit(dataset) @@ -54,6 +55,9 @@ message MlCommand { // or summary evaluated by a model message Delete { repeated ObjectRef obj_refs = 1; + // if set `evict_only` to true, only evict the cached model from memory, + // but keep the offloaded model in Spark driver local disk. + optional bool evict_only = 2; } // Force to clean up all the ML cached objects @@ -98,6 +102,13 @@ message MlCommand { // (Required) the evaluating dataset Relation dataset = 3; } + + // This is for re-creating the model summary when the model summary is lost + // (model summary is lost when the model is offloaded and then loaded back) + message CreateSummary { + ObjectRef model_ref = 1; + Relation dataset = 2; + } } // The result of MlCommand diff --git a/sql/connect/common/src/main/protobuf/spark/connect/relations.proto b/sql/connect/common/src/main/protobuf/spark/connect/relations.proto index 70a52a2111494..ccb674e812dc0 100644 --- a/sql/connect/common/src/main/protobuf/spark/connect/relations.proto +++ b/sql/connect/common/src/main/protobuf/spark/connect/relations.proto @@ -115,6 +115,9 @@ message MlRelation { Transform transform = 1; Fetch fetch = 2; } + // (Optional) the dataset for restoring the model summary + optional Relation model_summary_dataset = 3; + // Relation to represent transform(input) of the operator // which could be a cached model or a new transformer message Transform { diff --git a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/MLCache.scala b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/MLCache.scala index ef1b17dc2221e..b075187b7002f 100644 --- a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/MLCache.scala +++ b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/MLCache.scala @@ -30,7 +30,7 @@ import org.apache.commons.io.FileUtils import org.apache.spark.SparkException import org.apache.spark.internal.Logging import org.apache.spark.ml.Model -import org.apache.spark.ml.util.{ConnectHelper, MLWritable, Summary} +import org.apache.spark.ml.util.{ConnectHelper, HasTrainingSummary, MLWritable, Summary} import org.apache.spark.sql.connect.config.Connect import org.apache.spark.sql.connect.service.SessionHolder @@ -115,6 +115,12 @@ private[connect] class MLCache(sessionHolder: SessionHolder) extends Logging { } } + private[spark] def getModelOffloadingPath(refId: String): Path = { + val path = offloadedModelsDir.resolve(refId) + require(path.startsWith(offloadedModelsDir)) + path + } + /** * Cache an object into a map of MLCache, and return its key * @param obj @@ -137,9 +143,14 @@ private[connect] class MLCache(sessionHolder: SessionHolder) extends Logging { } cachedModel.put(objectId, CacheItem(obj, sizeBytes)) if (getMemoryControlEnabled) { - val savePath = offloadedModelsDir.resolve(objectId) - require(savePath.startsWith(offloadedModelsDir)) + val savePath = getModelOffloadingPath(objectId) obj.asInstanceOf[MLWritable].write.saveToLocal(savePath.toString) + if (obj.isInstanceOf[HasTrainingSummary[_]] + && 
obj.asInstanceOf[HasTrainingSummary[_]].hasSummary) { + obj + .asInstanceOf[HasTrainingSummary[_]] + .saveSummary(savePath.resolve("summary").toString) + } Files.writeString(savePath.resolve(modelClassNameFile), obj.getClass.getName) totalMLCacheInMemorySizeBytes.addAndGet(sizeBytes) totalMLCacheSizeBytes.addAndGet(sizeBytes) @@ -176,8 +187,7 @@ private[connect] class MLCache(sessionHolder: SessionHolder) extends Logging { verifyObjectId(refId) var obj: Object = Option(cachedModel.get(refId)).map(_.obj).getOrElse(null) if (obj == null && getMemoryControlEnabled) { - val loadPath = offloadedModelsDir.resolve(refId) - require(loadPath.startsWith(offloadedModelsDir)) + val loadPath = getModelOffloadingPath(refId) if (Files.isDirectory(loadPath)) { val className = Files.readString(loadPath.resolve(modelClassNameFile)) obj = MLUtils.loadTransformer( @@ -194,14 +204,13 @@ private[connect] class MLCache(sessionHolder: SessionHolder) extends Logging { } } - def _removeModel(refId: String): Boolean = { + def _removeModel(refId: String, evictOnly: Boolean): Boolean = { verifyObjectId(refId) val removedModel = cachedModel.remove(refId) val removedFromMem = removedModel != null - val removedFromDisk = if (removedModel != null && getMemoryControlEnabled) { + val removedFromDisk = if (!evictOnly && removedModel != null && getMemoryControlEnabled) { totalMLCacheSizeBytes.addAndGet(-removedModel.sizeBytes) - val removePath = offloadedModelsDir.resolve(refId) - require(removePath.startsWith(offloadedModelsDir)) + val removePath = getModelOffloadingPath(refId) val offloadingPath = new File(removePath.toString) if (offloadingPath.exists()) { FileUtils.deleteDirectory(offloadingPath) @@ -220,8 +229,8 @@ private[connect] class MLCache(sessionHolder: SessionHolder) extends Logging { * @param refId * the key used to look up the corresponding object */ - def remove(refId: String): Boolean = { - val modelIsRemoved = _removeModel(refId) + def remove(refId: String, evictOnly: Boolean = false): Boolean = { + val modelIsRemoved = _removeModel(refId, evictOnly) modelIsRemoved } diff --git a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/MLException.scala b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/MLException.scala index a017c719ed16e..847052be98a98 100644 --- a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/MLException.scala +++ b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/MLException.scala @@ -51,3 +51,9 @@ private[spark] case class MLCacheSizeOverflowException(mlCacheMaxSize: Long) errorClass = "CONNECT_ML.ML_CACHE_SIZE_OVERFLOW_EXCEPTION", messageParameters = Map("mlCacheMaxSize" -> mlCacheMaxSize.toString), cause = null) + +private[spark] case class MLModelSummaryLostException(objectName: String) + extends SparkException( + errorClass = "CONNECT_ML.MODEL_SUMMARY_LOST", + messageParameters = Map("objectName" -> objectName), + cause = null) diff --git a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/MLHandler.scala b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/MLHandler.scala index d40b70ba0813c..7220acb8feaca 100644 --- a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/MLHandler.scala +++ b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/MLHandler.scala @@ -229,9 +229,7 @@ private[connect] object MLHandler extends Logging { if (obj != null && obj.isInstanceOf[HasTrainingSummary[_]] && methods(0).getMethod == "summary" && 
!obj.asInstanceOf[HasTrainingSummary[_]].hasSummary) { - throw MLCacheInvalidException( - objRefId, - sessionHolder.mlCache.getOffloadingTimeoutMinute) + throw MLModelSummaryLostException(objRefId) } val helper = AttributeHelper(sessionHolder, objRefId, methods) val attrResult = helper.getAttribute @@ -264,9 +262,13 @@ private[connect] object MLHandler extends Logging { case proto.MlCommand.CommandCase.DELETE => val ids = mutable.ArrayBuilder.make[String] - mlCommand.getDelete.getObjRefsList.asScala.toArray.foreach { objId => + val deleteCmd = mlCommand.getDelete + val evictOnly = if (deleteCmd.hasEvictOnly) { + deleteCmd.getEvictOnly + } else { false } + deleteCmd.getObjRefsList.asScala.toArray.foreach { objId => if (!objId.getId.contains(".")) { - if (mlCache.remove(objId.getId)) { + if (mlCache.remove(objId.getId, evictOnly)) { ids += objId.getId } } @@ -400,10 +402,29 @@ private[connect] object MLHandler extends Logging { .setParam(LiteralValueProtoConverter.toLiteralProto(metric)) .build() + case proto.MlCommand.CommandCase.CREATE_SUMMARY => + val createSummaryCmd = mlCommand.getCreateSummary + createModelSummary(sessionHolder, createSummaryCmd) + case other => throw MlUnsupportedException(s"$other not supported") } } + private def createModelSummary( + sessionHolder: SessionHolder, + createSummaryCmd: proto.MlCommand.CreateSummary): proto.MlCommandResult = { + val refId = createSummaryCmd.getModelRef.getId + val model = sessionHolder.mlCache.get(refId).asInstanceOf[HasTrainingSummary[_]] + val dataset = MLUtils.parseRelationProto(createSummaryCmd.getDataset, sessionHolder) + val modelPath = sessionHolder.mlCache.getModelOffloadingPath(refId) + val summaryPath = modelPath.resolve("summary").toString + model.loadSummary(summaryPath, dataset) + proto.MlCommandResult + .newBuilder() + .setParam(LiteralValueProtoConverter.toLiteralProto(true)) + .build() + } + def transformMLRelation(relation: proto.MlRelation, sessionHolder: SessionHolder): DataFrame = { relation.getMlTypeCase match { // Ml transform @@ -433,10 +454,26 @@ private[connect] object MLHandler extends Logging { // Get the attribute from a cached object which could be a model or summary case proto.MlRelation.MlTypeCase.FETCH => - val helper = AttributeHelper( - sessionHolder, - relation.getFetch.getObjRef.getId, - relation.getFetch.getMethodsList.asScala.toArray) + val objRefId = relation.getFetch.getObjRef.getId + val methods = relation.getFetch.getMethodsList.asScala.toArray + val obj = sessionHolder.mlCache.get(objRefId) + if (obj != null && obj.isInstanceOf[HasTrainingSummary[_]] + && methods(0).getMethod == "summary" + && !obj.asInstanceOf[HasTrainingSummary[_]].hasSummary) { + + if (relation.hasModelSummaryDataset) { + val dataset = + MLUtils.parseRelationProto(relation.getModelSummaryDataset, sessionHolder) + val modelPath = sessionHolder.mlCache.getModelOffloadingPath(objRefId) + val summaryPath = modelPath.resolve("summary").toString + obj.asInstanceOf[HasTrainingSummary[_]].loadSummary(summaryPath, dataset) + } else { + // For old Spark client backward compatibility. + throw MLModelSummaryLostException(objRefId) + } + } + + val helper = AttributeHelper(sessionHolder, objRefId, methods) helper.getAttribute.asInstanceOf[DataFrame] case other => throw MlUnsupportedException(s"$other not supported")