Skip to content

Commit

Permalink
[SPARK-51091][ML][PYTHON][CONNECT] Fix the default params of `StopWordsRemover`
Browse files Browse the repository at this point in the history

### What changes were proposed in this pull request?
Fix the default params of `StopWordsRemover`

### Why are the changes needed?
For feature parity between Spark Connect and classic Spark ML.

### Does this PR introduce _any_ user-facing change?
yes

### How was this patch tested?
added tests

### Was this patch authored or co-authored using generative AI tooling?
no

Closes #49809 from zhengruifeng/ml_connect_swr.

Authored-by: Ruifeng Zheng <[email protected]>
Signed-off-by: Ruifeng Zheng <[email protected]>
(cherry picked from commit 4fd750c)
Signed-off-by: Ruifeng Zheng <[email protected]>
  • Loading branch information
zhengruifeng committed Feb 5, 2025
1 parent f916468 commit 9243ff6
Show file tree
Hide file tree
Showing 6 changed files with 62 additions and 42 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ package org.apache.spark.ml.feature
import java.util.Locale

import org.apache.spark.annotation.Since
import org.apache.spark.internal.{LogKeys, MDC}
import org.apache.spark.internal.{Logging, LogKeys, MDC}
import org.apache.spark.ml.Transformer
import org.apache.spark.ml.param._
import org.apache.spark.ml.param.shared.{HasInputCol, HasInputCols, HasOutputCol, HasOutputCols}
Expand Down Expand Up @@ -122,21 +122,6 @@ class StopWordsRemover @Since("1.5.0") (@Since("1.5.0") override val uid: String
@Since("2.4.0")
def getLocale: String = $(locale)

/**
 * Returns system default locale, or `Locale.US` if the default locale is not in available locales
 * in JVM.
 */
private val getDefaultOrUS: Locale = {
if (Locale.getAvailableLocales.contains(Locale.getDefault)) {
Locale.getDefault
} else {
// Degrade gracefully instead of failing: warn once and fall back to en_US.
// Users can still choose a different locale explicitly via the `locale` param.
logWarning(log"Default locale set was [${MDC(LogKeys.LOCALE, Locale.getDefault)}]; " +
log"however, it was not found in available locales in JVM, falling back to en_US locale. " +
log"Set param `locale` in order to respect another locale.")
Locale.US
}
}

/** Returns the input and output column names corresponding in pair. */
private[feature] def getInOutCols(): (Array[String], Array[String]) = {
if (isSet(inputCol)) {
Expand All @@ -147,7 +132,7 @@ class StopWordsRemover @Since("1.5.0") (@Since("1.5.0") override val uid: String
}

setDefault(stopWords -> StopWordsRemover.loadDefaultStopWords("english"),
caseSensitive -> false, locale -> getDefaultOrUS.toString)
caseSensitive -> false, locale -> StopWordsRemover.getDefaultOrUS.toString)

@Since("2.0.0")
override def transform(dataset: Dataset[_]): DataFrame = {
Expand Down Expand Up @@ -218,7 +203,7 @@ class StopWordsRemover @Since("1.5.0") (@Since("1.5.0") override val uid: String
}

@Since("1.6.0")
object StopWordsRemover extends DefaultParamsReadable[StopWordsRemover] {
object StopWordsRemover extends DefaultParamsReadable[StopWordsRemover] with Logging {

private[feature]
val supportedLanguages = Set("danish", "dutch", "english", "finnish", "french", "german",
Expand All @@ -241,4 +226,15 @@ object StopWordsRemover extends DefaultParamsReadable[StopWordsRemover] {
val is = getClass.getResourceAsStream(s"/org/apache/spark/ml/feature/stopwords/$language.txt")
scala.io.Source.fromInputStream(is)(scala.io.Codec.UTF8).getLines().toArray
}

/**
 * Returns the system default locale, or `Locale.US` if the default locale is not among the
 * locales available in this JVM. Widened to `private[spark]` so the Spark Connect helper
 * can resolve the same default locale on the server side.
 */
private[spark] def getDefaultOrUS: Locale = {
if (Locale.getAvailableLocales.contains(Locale.getDefault)) {
Locale.getDefault
} else {
// Warn rather than fail; users can override via the `locale` param of StopWordsRemover.
logWarning(log"Default locale set was [${MDC(LogKeys.LOCALE, Locale.getDefault)}]; " +
log"however, it was not found in available locales in JVM, falling back to en_US locale. " +
log"Set param `locale` in order to respect another locale.")
Locale.US
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
package org.apache.spark.ml.util

import org.apache.spark.ml.Model
import org.apache.spark.ml.feature.{CountVectorizerModel, StringIndexerModel}
import org.apache.spark.ml.feature._
import org.apache.spark.ml.param.ParamMap
import org.apache.spark.sql.{DataFrame, Dataset, SparkSession}
import org.apache.spark.sql.types.StructType
Expand Down Expand Up @@ -46,6 +46,14 @@ private[spark] class ConnectHelper(override val uid: String) extends Model[Conne
new CountVectorizerModel(uid, vocabulary)
}

/** Loads the default stop words for `language`, delegating to the `StopWordsRemover` companion. */
def stopWordsRemoverLoadDefaultStopWords(language: String): Array[String] =
  StopWordsRemover.loadDefaultStopWords(language)

/** Returns the string form of the default locale resolved by the `StopWordsRemover` companion. */
def stopWordsRemoverGetDefaultOrUS: String = {
  val defaultLocale = StopWordsRemover.getDefaultOrUS
  defaultLocale.toString
}

override def copy(extra: ParamMap): ConnectHelper = defaultCopy(extra)

override def transform(dataset: Dataset[_]): DataFrame = dataset.toDF()
Expand Down
30 changes: 15 additions & 15 deletions python/pyspark/ml/feature.py
Original file line number Diff line number Diff line change
Expand Up @@ -5043,7 +5043,6 @@ class StopWordsRemover(
Notes
-----
- null values from input array are preserved unless adding null to stopWords explicitly.
- In Spark Connect Mode, the default value of parameter `locale` and `stopWords` are not set.
Examples
--------
Expand Down Expand Up @@ -5142,19 +5141,14 @@ def __init__(
self._java_obj = self._new_java_obj(
"org.apache.spark.ml.feature.StopWordsRemover", self.uid
)
if isinstance(self._java_obj, str):
# Skip setting the default value of 'locale' and 'stopWords', which
# needs to invoke a JVM method.
# So if users don't explicitly set 'locale' and/or 'stopWords', then the getters fails.
self._setDefault(
caseSensitive=False,
)
if is_remote():
helper = JavaWrapper(java_obj=ML_CONNECT_HELPER_ID)
locale = helper._call_java("stopWordsRemoverGetDefaultOrUS")
else:
self._setDefault(
stopWords=StopWordsRemover.loadDefaultStopWords("english"),
caseSensitive=False,
locale=self._java_obj.getLocale(),
)
locale = self._java_obj.getLocale()

stopWords = StopWordsRemover.loadDefaultStopWords("english")
self._setDefault(stopWords=stopWords, caseSensitive=False, locale=locale)
kwargs = self._input_kwargs
self.setParams(**kwargs)

Expand Down Expand Up @@ -5279,8 +5273,14 @@ def loadDefaultStopWords(language: str) -> List[str]:
Supported languages: danish, dutch, english, finnish, french, german, hungarian,
italian, norwegian, portuguese, russian, spanish, swedish, turkish
"""
stopWordsObj = getattr(_jvm(), "org.apache.spark.ml.feature.StopWordsRemover")
return list(stopWordsObj.loadDefaultStopWords(language))
if is_remote():
helper = JavaWrapper(java_obj=ML_CONNECT_HELPER_ID)
stopWords = helper._call_java("stopWordsRemoverLoadDefaultStopWords", language)
return list(stopWords)

else:
stopWordsObj = getattr(_jvm(), "org.apache.spark.ml.feature.StopWordsRemover")
return list(stopWordsObj.loadDefaultStopWords(language))


class _TargetEncoderParams(
Expand Down
4 changes: 1 addition & 3 deletions python/pyspark/ml/tests/connect/test_parity_feature.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,7 @@


class FeatureParityTests(FeatureTestsMixin, ReusedConnectTestCase):
@unittest.skip("Need to support.")
def test_stop_words_lengague_selection(self):
super().test_stop_words_lengague_selection()
pass


if __name__ == "__main__":
Expand Down
24 changes: 20 additions & 4 deletions python/pyspark/ml/tests/test_feature.py
Original file line number Diff line number Diff line change
Expand Up @@ -883,8 +883,9 @@ def test_stop_words_remover(self):
remover2 = StopWordsRemover.load(d)
self.assertEqual(str(remover), str(remover2))

def test_stop_words_remover_II(self):
dataset = self.spark.createDataFrame([Row(input=["a", "panda"])])
def test_stop_words_remover_with_given_words(self):
spark = self.spark
dataset = spark.createDataFrame([Row(input=["a", "panda"])])
stopWordRemover = StopWordsRemover(inputCol="input", outputCol="output")
# Default
self.assertEqual(stopWordRemover.getInputCol(), "input")
Expand All @@ -905,15 +906,30 @@ def test_stop_words_remover_II(self):
transformedDF = stopWordRemover.transform(dataset)
self.assertEqual(transformedDF.head().output, [])

def test_stop_words_language_selection(self):
def test_stop_words_remover_with_turkish(self):
spark = self.spark
dataset = spark.createDataFrame([Row(input=["acaba", "ama", "biri"])])
stopWordRemover = StopWordsRemover(inputCol="input", outputCol="output")
stopwords = StopWordsRemover.loadDefaultStopWords("turkish")
dataset = self.spark.createDataFrame([Row(input=["acaba", "ama", "biri"])])
stopWordRemover.setStopWords(stopwords)
self.assertEqual(stopWordRemover.getStopWords(), stopwords)
transformedDF = stopWordRemover.transform(dataset)
self.assertEqual(transformedDF.head().output, [])

def test_stop_words_remover_default(self):
    """Check that a freshly constructed remover exposes usable default params."""
    remover = StopWordsRemover(inputCol="input", outputCol="output")

    # The default `locale` param must be a non-empty string.
    default_locale = remover.getLocale()
    self.assertIsInstance(default_locale, str)
    self.assertTrue(len(default_locale) > 0)

    # The default `stopWords` param must be a non-empty list of strings.
    default_words = remover.getStopWords()
    self.assertIsInstance(default_words, list)
    self.assertTrue(len(default_words) > 0)
    self.assertTrue(all(isinstance(word, str) for word in default_words))

def test_binarizer(self):
b0 = Binarizer()
self.assertListEqual(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -659,7 +659,9 @@ private[ml] object MLUtils {
"handleOverwrite",
"stringIndexerModelFromLabels",
"stringIndexerModelFromLabelsArray",
"countVectorizerModelFromVocabulary")))
"countVectorizerModelFromVocabulary",
"stopWordsRemoverLoadDefaultStopWords",
"stopWordsRemoverGetDefaultOrUS")))

private def validate(obj: Any, method: String): Unit = {
assert(obj != null)
Expand Down

0 comments on commit 9243ff6

Please sign in to comment.