[jvm-packages] LTR: distribute the features with same group into same partition (#11023)
wbo4958 authored Dec 3, 2024
1 parent a03b92e commit e25d56d
Showing 3 changed files with 114 additions and 0 deletions.
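The commit message above summarizes the change; before the diff itself, here is a minimal standalone sketch of the idea (the object name, the local SparkSession setup, and the toy data are illustrative, not taken from the patch): a plain repartition distributes rows round-robin, so rows of one query group can be scattered across partitions, whereas repartitionByRange on the group column keeps each group inside a single partition, which is what LTR training needs.

import org.apache.spark.sql.{DataFrame, SparkSession}
import org.apache.spark.sql.functions.col

// Illustrative sketch of the round-robin vs. range-partitioning behaviour described
// in the commit message; this is not code from the patch.
object GroupPartitionSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .master("local[*]")
      .appName("group-partition-sketch")
      .getOrCreate()
    import spark.implicits._

    // Nine rows in three query groups (0, 1, 2), all starting in a single partition.
    val df = spark.sparkContext
      .parallelize(Seq.tabulate(9)(i => (0.1, 1.0, i / 3)), 1)
      .toDF("label", "f1", "group")

    // Render the group ids seen by each partition as one comma-separated string.
    def layout(d: DataFrame): Array[String] =
      d.select("group")
        .mapPartitions(it => Iterator.single(it.map(_.getInt(0)).mkString(",")))
        .collect()

    // Round-robin repartition: rows of one group are spread over several partitions,
    // so after sorting within partitions a partition typically reads "0,1,2".
    println(layout(df.repartition(3).sortWithinPartitions(col("group"))).mkString(" | "))

    // Range partitioning on the group column: each group stays in one partition,
    // so every partition reads "0,0,0", "1,1,1" or "2,2,2".
    println(layout(df.repartitionByRange(3, col("group"))).mkString(" | "))

    spark.stop()
  }
}

The tests added in this commit assert exactly this contrast, first against a plain repartition and then against the ranker's preprocessing.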

@@ -542,6 +542,55 @@ class GpuXGBoostPluginSuite extends GpuTestSuite {
    }
  }

test("Same group must be in the same partition") {
val num_workers = 3
withGpuSparkSession() { spark =>
import spark.implicits._
val df = spark.createDataFrame(spark.sparkContext.parallelize(Seq(
(0.1, 1, 0),
(0.1, 1, 0),
(0.1, 1, 0),
(0.1, 1, 1),
(0.1, 1, 1),
(0.1, 1, 1),
(0.1, 1, 2),
(0.1, 1, 2),
(0.1, 1, 2)), 1)).toDF("label", "f1", "group")

// The original pattern will repartition df in a RoundRobin manner
val oriRows = df.repartition(num_workers)
.sortWithinPartitions(df.col("group"))
.select("group")
.mapPartitions { case iter =>
val tmp: ArrayBuffer[Int] = ArrayBuffer.empty
while (iter.hasNext) {
val r = iter.next()
tmp.append(r.getInt(0))
}
Iterator.single(tmp.mkString(","))
}.collect()
assert(oriRows.length == 3)
assert(oriRows.contains("0,1,2"))

// The fix has replaced repartition with repartitionByRange which will put the
// instances with same group into the same partition
val ranker = new XGBoostRanker().setGroupCol("group").setNumWorkers(num_workers)
val processedDf = ranker.getPlugin.get.asInstanceOf[GpuXGBoostPlugin].preprocess(ranker, df)
val rows = processedDf
.select("group")
.mapPartitions { case iter =>
val tmp: ArrayBuffer[Int] = ArrayBuffer.empty
while (iter.hasNext) {
val r = iter.next()
tmp.append(r.getInt(0))
}
Iterator.single(tmp.mkString(","))
}.collect()

rows.forall(Seq("0,0,0", "1,1,1", "2,2,2").contains)
}
}

test("Ranker: XGBoost-Spark should match xgboost4j") {
withGpuSparkSession() { spark =>
import spark.implicits._

@@ -22,6 +22,7 @@ import org.apache.spark.ml.param.ParamMap
import org.apache.spark.ml.util.{DefaultParamsReadable, Identifiable, MLReadable, MLReader}
import org.apache.spark.ml.xgboost.SparkUtils
import org.apache.spark.sql.Dataset
import org.apache.spark.sql.functions.col
import org.apache.spark.sql.types.{DataType, DoubleType, StructType}

import ml.dmlc.xgboost4j.scala.Booster

@@ -62,6 +63,22 @@ class XGBoostRanker(override val uid: String,
    }
  }

  /**
   * Repartition the dataset to numWorkers partitions if needed.
   *
   * @param dataset the dataset to be repartitioned
   * @return the repartitioned dataset
   */
  override private[spark] def repartitionIfNeeded(dataset: Dataset[_]) = {
    val numPartitions = dataset.rdd.getNumPartitions
    if (getForceRepartition || getNumWorkers != numPartitions) {
      // Please note that the output of repartitionByRange is not deterministic
      dataset.repartitionByRange(getNumWorkers, col(getGroupCol))
    } else {
      dataset
    }
  }

  /**
   * Sort partition for Ranker issue.
   *
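For context, a minimal standalone sketch of the preprocessing pattern this class now follows (the object and helper names and the parameter list are illustrative stand-ins for the ranker params, not part of the patch): range-partition on the group column so each query group lands in a single partition, then sort within each partition so the rows of a group are contiguous.

import org.apache.spark.sql.{DataFrame, functions => F}

// Illustrative helper mirroring the repartition-then-sort flow for ranking input;
// this is an assumption-level sketch, not the patched XGBoostRanker code.
object RankingRepartitionSketch {
  def prepareRankingInput(df: DataFrame, numWorkers: Int, groupCol: String): DataFrame = {
    val partitioned =
      if (df.rdd.getNumPartitions != numWorkers) {
        // repartitionByRange keeps every value of groupCol inside one partition; which
        // partition a given group lands in is not deterministic, because the range
        // boundaries are chosen by sampling the data.
        df.repartitionByRange(numWorkers, F.col(groupCol))
      } else {
        df
      }
    // Sort within each partition so all rows of a group sit next to each other.
    partitioned.sortWithinPartitions(F.col(groupCol))
  }
}

In the ranker itself this pattern is split between the repartitionIfNeeded override above and the partition-sorting step whose scaladoc is truncated just before this point.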

@@ -151,6 +151,54 @@ class XGBoostRankerSuite extends AnyFunSuite with PerTest with TmpFolderPerSuite
}}
}

test("Same group must be in the same partition") {
val spark = ss
import spark.implicits._
val num_workers = 3
val df = ss.createDataFrame(sc.parallelize(Seq(
(0.1, Vectors.dense(1.0, 2.0, 3.0), 0),
(0.1, Vectors.dense(0.0, 0.0, 0.0), 0),
(0.1, Vectors.dense(0.0, 3.0, 0.0), 0),
(0.1, Vectors.dense(2.0, 0.0, 4.0), 1),
(0.1, Vectors.dense(0.2, 1.2, 2.0), 1),
(0.1, Vectors.dense(0.5, 2.2, 1.7), 1),
(0.1, Vectors.dense(0.5, 2.2, 1.7), 2),
(0.1, Vectors.dense(0.5, 2.2, 1.7), 2),
(0.1, Vectors.dense(0.5, 2.2, 1.7), 2)), 1)).toDF("label", "features", "group")

// The original pattern will repartition df in a RoundRobin manner
val oriRows = df.repartition(num_workers)
.sortWithinPartitions(df.col("group"))
.select("group")
.mapPartitions { case iter =>
val tmp: ArrayBuffer[Int] = ArrayBuffer.empty
while (iter.hasNext) {
val r = iter.next()
tmp.append(r.getInt(0))
}
Iterator.single(tmp.mkString(","))
}.collect()
assert(oriRows.length == 3)
assert(oriRows.contains("0,1,2"))

// The fix has replaced repartition with repartitionByRange which will put the
// instances with same group into the same partition
val ranker = new XGBoostRanker().setGroupCol("group").setNumWorkers(num_workers)
val (processedDf, _) = ranker.preprocess(df)
val rows = processedDf
.select("group")
.mapPartitions { case iter =>
val tmp: ArrayBuffer[Int] = ArrayBuffer.empty
while (iter.hasNext) {
val r = iter.next()
tmp.append(r.getInt(0))
}
Iterator.single(tmp.mkString(","))
}.collect()

rows.forall(Seq("0,0,0", "1,1,1", "2,2,2").contains)
}

  private def runLengthEncode(input: Seq[Int]): Seq[Int] = {
    if (input.isEmpty) return Seq(0)
