From fb080e824cd854435de41d830b70ed9a9905bedc Mon Sep 17 00:00:00 2001 From: Weichen Xu Date: Mon, 12 Aug 2024 19:25:42 +0800 Subject: [PATCH] simplify Signed-off-by: Weichen Xu --- .../apache/spark/ml/feature/Bucketizer.scala | 10 ++------- .../org/apache/spark/ml/feature/Imputer.scala | 6 +++--- .../spark/ml/feature/OneHotEncoder.scala | 11 ++-------- .../ml/feature/QuantileDiscretizer.scala | 11 ++-------- .../apache/spark/ml/util/SchemaUtils.scala | 21 +++++-------------- 5 files changed, 14 insertions(+), 45 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala index 54704ae884d14..9a30c0895c17f 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala @@ -192,10 +192,6 @@ final class Bucketizer @Since("1.4.0") (@Since("1.4.0") override val uid: String ParamValidators.checkSingleVsMultiColumnParams(this, Seq(outputCol, splits), Seq(outputCols, splitsArray)) - val sparkSession = SparkSession.getDefaultSession.get - val transformDataset = sparkSession.createDataFrame( - new ju.ArrayList[Row](), schema = schema - ) if (isSet(inputCols)) { require(getInputCols.length == getOutputCols.length && getInputCols.length == getSplitsArray.length, s"Bucketizer $this has mismatched Params " + @@ -205,15 +201,13 @@ final class Bucketizer @Since("1.4.0") (@Since("1.4.0") override val uid: String var transformedSchema = schema $(inputCols).zip($(outputCols)).zipWithIndex.foreach { case ((inputCol, outputCol), idx) => - val colType = transformDataset.col(inputCol).expr.dataType - SchemaUtils.checkNumericType(colType, inputCol, "") + SchemaUtils.checkNumericType(schema, inputCol) transformedSchema = SchemaUtils.appendColumn(transformedSchema, prepOutputField($(splitsArray)(idx), outputCol)) } transformedSchema } else { - val colType = transformDataset.col($(inputCol)).expr.dataType - SchemaUtils.checkNumericType(colType, $(inputCol), "") + SchemaUtils.checkNumericType(schema, $(inputCol)) SchemaUtils.appendColumn(schema, prepOutputField($(splits), $(outputCol))) } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Imputer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Imputer.scala index 78d23e182b7ea..07d84720b1754 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/Imputer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Imputer.scala @@ -91,7 +91,7 @@ private[feature] trait ImputerParams extends Params with HasInputCol with HasInp s" and outputCols(${outputColNames.length}) should have the same length") val outputFields = inputColNames.zip(outputColNames).map { case (inputCol, outputCol) => val inputField = SchemaUtils.getSchemaField(schema, inputCol) - SchemaUtils.checkNumericType(inputField.dataType, inputCol, "") + SchemaUtils.checkNumericType(schema, inputCol) StructField(outputCol, inputField.dataType, inputField.nullable) } StructType(schema ++ outputFields) @@ -179,8 +179,8 @@ class Imputer @Since("2.2.0") (@Since("2.2.0") override val uid: String) val quantileDataset = dataset.select(inputColumns.zipWithIndex.map { case (colName, index) => col(colName).alias(quantileColNames(index)) }.toImmutableArraySeq: _*) - quantileDataset.select(cols.toImmutableArraySeq: _*) - .stat.approxQuantile(inputColumns, Array(0.5), $(relativeError)) + quantileDataset + .stat.approxQuantile(quantileColNames, Array(0.5), $(relativeError)) .map(_.headOption.getOrElse(Double.NaN)) case Imputer.mode => diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoder.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoder.scala index fd10040194388..44b8b2047681b 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoder.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoder.scala @@ -17,8 +17,6 @@ package org.apache.spark.ml.feature -import java.{util => ju} - import org.apache.hadoop.fs.Path import org.apache.spark.SparkException @@ -29,7 +27,7 @@ import org.apache.spark.ml.linalg.Vectors import org.apache.spark.ml.param._ import org.apache.spark.ml.param.shared.{HasHandleInvalid, HasInputCol, HasInputCols, HasOutputCol, HasOutputCols} import org.apache.spark.ml.util._ -import org.apache.spark.sql.{DataFrame, Dataset, Row, SparkSession} +import org.apache.spark.sql.{DataFrame, Dataset} import org.apache.spark.sql.expressions.UserDefinedFunction import org.apache.spark.sql.functions.{col, lit, udf} import org.apache.spark.sql.types.{DoubleType, StructField, StructType} @@ -90,14 +88,9 @@ private[ml] trait OneHotEncoderBase extends Params with HasHandleInvalid s"The number of input columns ${inputColNames.length} must be the same as the number of " + s"output columns ${outputColNames.length}.") - val sparkSession = SparkSession.getDefaultSession.get - val transformDataset = sparkSession.createDataFrame( - new ju.ArrayList[Row](), schema = schema - ) // Input columns must be NumericType. inputColNames.foreach { colName => - val dataType = transformDataset.col(colName).expr.dataType - SchemaUtils.checkNumericType(dataType, colName, "") + SchemaUtils.checkNumericType(schema, colName) } // Prepares output columns with proper attributes by examining input columns. diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/QuantileDiscretizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/QuantileDiscretizer.scala index 18bb5d5151fed..fbec76cc79cfb 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/QuantileDiscretizer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/QuantileDiscretizer.scala @@ -17,8 +17,6 @@ package org.apache.spark.ml.feature -import java.{util => ju} - import org.apache.spark.annotation.Since import org.apache.spark.internal.Logging import org.apache.spark.ml._ @@ -26,7 +24,7 @@ import org.apache.spark.ml.attribute.NominalAttribute import org.apache.spark.ml.param._ import org.apache.spark.ml.param.shared._ import org.apache.spark.ml.util._ -import org.apache.spark.sql.{Dataset, Row, SparkSession} +import org.apache.spark.sql.Dataset import org.apache.spark.sql.functions.col import org.apache.spark.sql.types.StructType import org.apache.spark.util.ArrayImplicits._ @@ -190,13 +188,8 @@ final class QuantileDiscretizer @Since("1.6.0") (@Since("1.6.0") override val ui var outputFields = schema.fields - val sparkSession = SparkSession.getDefaultSession.get - val transformDataset = sparkSession.createDataFrame( - new ju.ArrayList[Row](), schema = schema - ) inputColNames.zip(outputColNames).foreach { case (inputColName, outputColName) => - val colType = transformDataset.col(inputColName).expr.dataType - SchemaUtils.checkNumericType(colType, inputColName, "") + SchemaUtils.checkNumericType(schema, inputColName) require(!schema.fieldNames.contains(outputColName), s"Output column $outputColName already exists.") val attr = NominalAttribute.defaultAttr.withName(outputColName) diff --git a/mllib/src/main/scala/org/apache/spark/ml/util/SchemaUtils.scala b/mllib/src/main/scala/org/apache/spark/ml/util/SchemaUtils.scala index 70f6ee1b617be..5b0e47d75b80a 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/util/SchemaUtils.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/util/SchemaUtils.scala @@ -73,8 +73,11 @@ private[spark] object SchemaUtils { schema: StructType, colName: String, msg: String): Unit = { - val actualDataType = schema(colName).dataType - checkNumericType(actualDataType, colName, msg) + val actualDataType = getSchemaFieldType(schema, colName) + val message = if (msg != null && msg.trim.length > 0) " " + msg else "" + require(actualDataType.isInstanceOf[NumericType], + s"Column $colName must be of type ${NumericType.simpleString} but was actually of type " + + s"${actualDataType.catalogString}.$message") } /** @@ -87,20 +90,6 @@ private[spark] object SchemaUtils { checkNumericType(schema, colName, "") } - /** - * Check whether the given actual data type is the numeric data type. - * @param actualDataType actual data type of the column - */ - def checkNumericType( - actualDataType: DataType, - colName: String, - msg: String): Unit = { - val message = if (msg != null && msg.trim.length > 0) " " + msg else "" - require(actualDataType.isInstanceOf[NumericType], - s"Column $colName must be of type ${NumericType.simpleString} but was actually of type " + - s"${actualDataType.catalogString}.$message") - } - /** * Appends a new column to the input schema. This fails if the given output column already exists. * @param schema input schema