Skip to content

[SPARK-52470][ML][CONNECT] Support model summary offloading #51187

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 24 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions common/utils/src/main/resources/error/error-conditions.json
Original file line number Diff line number Diff line change
Expand Up @@ -853,6 +853,11 @@
"Please fit or load a model smaller than <modelMaxSize> bytes."
]
},
"MODEL_SUMMARY_LOST" : {
"message" : [
"The model <objectName> summary is lost because the cached model is offloaded."
]
},
"UNSUPPORTED_EXCEPTION" : {
"message" : [
"<message>"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -224,17 +224,8 @@ class FMClassifier @Since("3.0.0") (
factors: Matrix,
objectiveHistory: Array[Double]): FMClassificationModel = {
val model = copyValues(new FMClassificationModel(uid, intercept, linear, factors))
val weightColName = if (!isDefined(weightCol)) "weightCol" else $(weightCol)

val (summaryModel, probabilityColName, predictionColName) = model.findSummaryModel()
val summary = new FMClassificationTrainingSummaryImpl(
summaryModel.transform(dataset),
probabilityColName,
predictionColName,
$(labelCol),
weightColName,
objectiveHistory)
model.setSummary(Some(summary))
model.createSummary(dataset, objectiveHistory)
model
}

@Since("3.0.0")
Expand Down Expand Up @@ -343,6 +334,42 @@ class FMClassificationModel private[classification] (
s"uid=${super.toString}, numClasses=$numClasses, numFeatures=$numFeatures, " +
s"factorSize=${$(factorSize)}, fitLinear=${$(fitLinear)}, fitIntercept=${$(fitIntercept)}"
}

private[spark] def createSummary(
    dataset: Dataset[_], objectiveHistory: Array[Double]
): Unit = {
  // Fall back to the literal column name "weightCol" when no weight column was set.
  val weightColName = if (isDefined(weightCol)) $(weightCol) else "weightCol"

  val (summaryModel, probabilityColName, predictionColName) = findSummaryModel()
  val trainingSummary = new FMClassificationTrainingSummaryImpl(
    summaryModel.transform(dataset),
    probabilityColName,
    predictionColName,
    $(labelCol),
    weightColName,
    objectiveHistory)
  setSummary(Some(trainingSummary))
}

override private[spark] def saveSummary(path: String): Unit = {
  // Persist only the objective history; the rest of the summary is rebuilt on load.
  val data = Tuple1(summary.objectiveHistory)
  ReadWriteUtils.saveObjectToLocal[Tuple1[Array[Double]]](
    path, data,
    (t, outputStream) => ReadWriteUtils.serializeDoubleArray(t._1, outputStream)
  )
}

override private[spark] def loadSummary(path: String, dataset: DataFrame): Unit = {
  // Read the stored objective history back, then recompute the summary from `dataset`.
  val loaded = ReadWriteUtils.loadObjectFromLocal[Tuple1[Array[Double]]](
    path,
    inputStream => Tuple1(ReadWriteUtils.deserializeDoubleArray(inputStream))
  )
  createSummary(dataset, loaded._1)
}
}

@Since("3.0.0")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -277,17 +277,8 @@ class LinearSVC @Since("2.2.0") (
intercept: Double,
objectiveHistory: Array[Double]): LinearSVCModel = {
val model = copyValues(new LinearSVCModel(uid, coefficients, intercept))
val weightColName = if (!isDefined(weightCol)) "weightCol" else $(weightCol)

val (summaryModel, rawPredictionColName, predictionColName) = model.findSummaryModel()
val summary = new LinearSVCTrainingSummaryImpl(
summaryModel.transform(dataset),
rawPredictionColName,
predictionColName,
$(labelCol),
weightColName,
objectiveHistory)
model.setSummary(Some(summary))
model.createSummary(dataset, objectiveHistory)
model
}

private def trainImpl(
Expand Down Expand Up @@ -445,6 +436,42 @@ class LinearSVCModel private[classification] (
override def toString: String = {
s"LinearSVCModel: uid=$uid, numClasses=$numClasses, numFeatures=$numFeatures"
}

private[spark] def createSummary(
    dataset: Dataset[_], objectiveHistory: Array[Double]
): Unit = {
  // Fall back to the literal column name "weightCol" when no weight column was set.
  val weightColName = if (isDefined(weightCol)) $(weightCol) else "weightCol"

  val (summaryModel, rawPredictionColName, predictionColName) = findSummaryModel()
  val trainingSummary = new LinearSVCTrainingSummaryImpl(
    summaryModel.transform(dataset),
    rawPredictionColName,
    predictionColName,
    $(labelCol),
    weightColName,
    objectiveHistory)
  setSummary(Some(trainingSummary))
}

override private[spark] def saveSummary(path: String): Unit = {
  // Persist only the objective history; the rest of the summary is rebuilt on load.
  val data = Tuple1(summary.objectiveHistory)
  ReadWriteUtils.saveObjectToLocal[Tuple1[Array[Double]]](
    path, data,
    (t, outputStream) => ReadWriteUtils.serializeDoubleArray(t._1, outputStream)
  )
}

override private[spark] def loadSummary(path: String, dataset: DataFrame): Unit = {
  // Read the stored objective history back, then recompute the summary from `dataset`.
  val loaded = ReadWriteUtils.loadObjectFromLocal[Tuple1[Array[Double]]](
    path,
    inputStream => Tuple1(ReadWriteUtils.deserializeDoubleArray(inputStream))
  )
  createSummary(dataset, loaded._1)
}
}

@Since("2.2.0")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -718,29 +718,8 @@ class LogisticRegression @Since("1.2.0") (
objectiveHistory: Array[Double]): LogisticRegressionModel = {
val model = copyValues(new LogisticRegressionModel(uid, coefficientMatrix, interceptVector,
numClasses, checkMultinomial(numClasses)))
val weightColName = if (!isDefined(weightCol)) "weightCol" else $(weightCol)

val (summaryModel, probabilityColName, predictionColName) = model.findSummaryModel()
val logRegSummary = if (numClasses <= 2) {
new BinaryLogisticRegressionTrainingSummaryImpl(
summaryModel.transform(dataset),
probabilityColName,
predictionColName,
$(labelCol),
$(featuresCol),
weightColName,
objectiveHistory)
} else {
new LogisticRegressionTrainingSummaryImpl(
summaryModel.transform(dataset),
probabilityColName,
predictionColName,
$(labelCol),
$(featuresCol),
weightColName,
objectiveHistory)
}
model.setSummary(Some(logRegSummary))
model.createSummary(dataset, objectiveHistory)
model
}

private def createBounds(
Expand Down Expand Up @@ -1323,6 +1302,54 @@ class LogisticRegressionModel private[spark] (
override def toString: String = {
s"LogisticRegressionModel: uid=$uid, numClasses=$numClasses, numFeatures=$numFeatures"
}

private[spark] def createSummary(
    dataset: Dataset[_], objectiveHistory: Array[Double]
): Unit = {
  // Fall back to the literal column name "weightCol" when no weight column was set.
  val weightColName = if (isDefined(weightCol)) $(weightCol) else "weightCol"

  val (summaryModel, probabilityColName, predictionColName) = findSummaryModel()
  val predictions = summaryModel.transform(dataset)
  // Binary problems get the richer binary summary; multiclass gets the generic one.
  val trainingSummary =
    if (numClasses <= 2) {
      new BinaryLogisticRegressionTrainingSummaryImpl(
        predictions,
        probabilityColName,
        predictionColName,
        $(labelCol),
        $(featuresCol),
        weightColName,
        objectiveHistory)
    } else {
      new LogisticRegressionTrainingSummaryImpl(
        predictions,
        probabilityColName,
        predictionColName,
        $(labelCol),
        $(featuresCol),
        weightColName,
        objectiveHistory)
    }
  setSummary(Some(trainingSummary))
}

override private[spark] def saveSummary(path: String): Unit = {
  // Persist only the objective history; the rest of the summary is rebuilt on load.
  val data = Tuple1(summary.objectiveHistory)
  ReadWriteUtils.saveObjectToLocal[Tuple1[Array[Double]]](
    path, data,
    (t, outputStream) => ReadWriteUtils.serializeDoubleArray(t._1, outputStream)
  )
}

override private[spark] def loadSummary(path: String, dataset: DataFrame): Unit = {
  // Read the stored objective history back, then recompute the summary from `dataset`.
  val loaded = ReadWriteUtils.loadObjectFromLocal[Tuple1[Array[Double]]](
    path,
    inputStream => Tuple1(ReadWriteUtils.deserializeDoubleArray(inputStream))
  )
  createSummary(dataset, loaded._1)
}
}

@Since("1.6.0")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -251,14 +251,8 @@ class MultilayerPerceptronClassifier @Since("1.5.0") (
objectiveHistory: Array[Double]): MultilayerPerceptronClassificationModel = {
val model = copyValues(new MultilayerPerceptronClassificationModel(uid, weights))

val (summaryModel, _, predictionColName) = model.findSummaryModel()
val summary = new MultilayerPerceptronClassificationTrainingSummaryImpl(
summaryModel.transform(dataset),
predictionColName,
$(labelCol),
"",
objectiveHistory)
model.setSummary(Some(summary))
model.createSummary(dataset, objectiveHistory)
model
}
}

Expand Down Expand Up @@ -365,6 +359,39 @@ class MultilayerPerceptronClassificationModel private[ml] (
s"MultilayerPerceptronClassificationModel: uid=$uid, numLayers=${$(layers).length}, " +
s"numClasses=$numClasses, numFeatures=$numFeatures"
}

private[spark] def createSummary(
    dataset: Dataset[_], objectiveHistory: Array[Double]
): Unit = {
  val (summaryModel, _, predictionColName) = findSummaryModel()
  // The empty string is passed for the weight column name, matching the
  // original call site (MLP training does not wire a weight column here).
  val trainingSummary = new MultilayerPerceptronClassificationTrainingSummaryImpl(
    summaryModel.transform(dataset),
    predictionColName,
    $(labelCol),
    "",
    objectiveHistory)
  setSummary(Some(trainingSummary))
}

override private[spark] def saveSummary(path: String): Unit = {
  // Persist only the objective history; the rest of the summary is rebuilt on load.
  val data = Tuple1(summary.objectiveHistory)
  ReadWriteUtils.saveObjectToLocal[Tuple1[Array[Double]]](
    path, data,
    (t, outputStream) => ReadWriteUtils.serializeDoubleArray(t._1, outputStream)
  )
}

override private[spark] def loadSummary(path: String, dataset: DataFrame): Unit = {
  // Read the stored objective history back, then recompute the summary from `dataset`.
  val loaded = ReadWriteUtils.loadObjectFromLocal[Tuple1[Array[Double]]](
    path,
    inputStream => Tuple1(ReadWriteUtils.deserializeDoubleArray(inputStream))
  )
  createSummary(dataset, loaded._1)
}
}

@Since("2.0.0")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -182,26 +182,8 @@ class RandomForestClassifier @Since("1.4.0") (
numFeatures: Int,
numClasses: Int): RandomForestClassificationModel = {
val model = copyValues(new RandomForestClassificationModel(uid, trees, numFeatures, numClasses))
val weightColName = if (!isDefined(weightCol)) "weightCol" else $(weightCol)

val (summaryModel, probabilityColName, predictionColName) = model.findSummaryModel()
val rfSummary = if (numClasses <= 2) {
new BinaryRandomForestClassificationTrainingSummaryImpl(
summaryModel.transform(dataset),
probabilityColName,
predictionColName,
$(labelCol),
weightColName,
Array(0.0))
} else {
new RandomForestClassificationTrainingSummaryImpl(
summaryModel.transform(dataset),
predictionColName,
$(labelCol),
weightColName,
Array(0.0))
}
model.setSummary(Some(rfSummary))
model.createSummary(dataset)
model
}

@Since("1.4.1")
Expand Down Expand Up @@ -393,6 +375,35 @@ class RandomForestClassificationModel private[ml] (
@Since("2.0.0")
override def write: MLWriter =
new RandomForestClassificationModel.RandomForestClassificationModelWriter(this)

private[spark] def createSummary(dataset: Dataset[_]): Unit = {
  // Fall back to the literal column name "weightCol" when no weight column was set.
  val weightColName = if (isDefined(weightCol)) $(weightCol) else "weightCol"

  val (summaryModel, probabilityColName, predictionColName) = findSummaryModel()
  val predictions = summaryModel.transform(dataset)
  // The forest has no per-iteration objective, hence the single-element Array(0.0).
  val trainingSummary =
    if (numClasses <= 2) {
      new BinaryRandomForestClassificationTrainingSummaryImpl(
        predictions,
        probabilityColName,
        predictionColName,
        $(labelCol),
        weightColName,
        Array(0.0))
    } else {
      new RandomForestClassificationTrainingSummaryImpl(
        predictions,
        predictionColName,
        $(labelCol),
        weightColName,
        Array(0.0))
    }
  setSummary(Some(trainingSummary))
}

override private[spark] def saveSummary(path: String): Unit = {}

override private[spark] def loadSummary(path: String, dataset: DataFrame): Unit = {
  // saveSummary persisted nothing for this model, so the summary is recomputed
  // entirely from the provided dataset; `path` is intentionally unused.
  createSummary(dataset)
}
}

@Since("2.0.0")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -180,6 +180,9 @@ class BisectingKMeansModel private[ml] (
override def summary: BisectingKMeansSummary = super.summary

override def estimatedSize: Long = SizeEstimator.estimate(parentModel)

// BisectingKMeansModel does not support summary offloading yet, so `saveSummary`
// is left as an empty placeholder for now.
override private[spark] def saveSummary(path: String): Unit = {}
}

object BisectingKMeansModel extends MLReadable[BisectingKMeansModel] {
Expand Down
Loading