Skip to content

Commit

Permalink
Fix spark test.
Browse files: browse the repository at this point in the history
  • Loading branch information
trivialfis committed Apr 3, 2023
1 parent 1efab87 commit b851034
Showing 1 changed file with 34 additions and 6 deletions.
40 changes: 34 additions & 6 deletions tests/test_distributed/test_with_spark/test_spark_local.py
Original file line number Diff line number Diff line change
Expand Up @@ -1359,14 +1359,42 @@ def ltr_data(spark: SparkSession) -> Generator[LTRData, None, None]:
],
["features", "label", "qid"],
)
X_train = np.array(
[
[1.0, 2.0, 3.0],
[4.0, 5.0, 6.0],
[9.0, 4.0, 8.0],
[np.NaN, 1.0, 5.5],
[np.NaN, 6.0, 7.5],
[np.NaN, 8.0, 9.5],
]
)
qid_train = np.array([0, 0, 0, 1, 1, 1])
y_train = np.array([0, 1, 2, 0, 1, 2])

X_test = np.array(
[
[1.5, 2.0, 3.0],
[4.5, 5.0, 6.0],
[9.0, 4.5, 8.0],
[np.NaN, 1.0, 6.0],
[np.NaN, 6.0, 7.0],
[np.NaN, 8.0, 10.5],
]
)

ltr = xgb.XGBRanker(tree_method="approx", objective="rank:pairwise")
ltr.fit(X_train, y_train, qid=qid_train)
predt = ltr.predict(X_test)

ranker_df_test = spark.createDataFrame(
[
(Vectors.dense(1.5, 2.0, 3.0), 0, -1.75218),
(Vectors.dense(4.5, 5.0, 6.0), 0, -0.34192949533462524),
(Vectors.dense(9.0, 4.5, 8.0), 0, 1.7251298427581787),
(Vectors.sparse(3, {1: 1.0, 2: 6.0}), 1, -1.7521828413009644),
(Vectors.sparse(3, {1: 6.0, 2: 7.0}), 1, -1.0988065004348755),
(Vectors.sparse(3, {1: 8.0, 2: 10.5}), 1, 1.632674217224121),
(Vectors.dense(1.5, 2.0, 3.0), 0, float(predt[0])),
(Vectors.dense(4.5, 5.0, 6.0), 0, float(predt[1])),
(Vectors.dense(9.0, 4.5, 8.0), 0, float(predt[2])),
(Vectors.sparse(3, {1: 1.0, 2: 6.0}), 1, float(predt[3])),
(Vectors.sparse(3, {1: 6.0, 2: 7.0}), 1, float(predt[4])),
(Vectors.sparse(3, {1: 8.0, 2: 10.5}), 1, float(predt[5])),
],
["features", "qid", "expected_prediction"],
)
Expand Down

0 comments on commit b851034

Please sign in to comment.