25 changes: 18 additions & 7 deletions jvm-packages/pom.xml
@@ -5,7 +5,7 @@
<modelVersion>4.0.0</modelVersion>

<groupId>ml.dmlc</groupId>
<artifactId>xgboost-jvm_2.12</artifactId>
<artifactId>xgboost-jvm_2.13</artifactId>
<version>3.1.0-SNAPSHOT</version>
<packaging>pom</packaging>
<name>XGBoost JVM Package</name>
@@ -43,15 +43,15 @@
<properties>
<project.build.sourceEncoding>UTF-8</project.build.sourceEncoding>
<project.reporting.outputEncoding>UTF-8</project.reporting.outputEncoding>
<maven.compiler.source>1.8</maven.compiler.source>
<maven.compiler.target>1.8</maven.compiler.target>
<maven.compiler.source>17</maven.compiler.source>
<maven.compiler.target>17</maven.compiler.target>
<flink.version>1.20.0</flink.version>
<junit.version>4.13.2</junit.version>
<spark.version>3.5.3</spark.version>
<spark.version.gpu>3.5.1</spark.version.gpu>
<spark.version>4.0.1-SNAPSHOT</spark.version>
<spark.version.gpu>4.0.1-SNAPSHOT</spark.version.gpu>
<fasterxml.jackson.version>2.15.0</fasterxml.jackson.version>
<scala.version>2.12.18</scala.version>
<scala.binary.version>2.12</scala.binary.version>
<scala.version>2.13.11</scala.version>
<scala.binary.version>2.13</scala.binary.version>
<hadoop.version>3.4.1</hadoop.version>
<maven.wagon.http.retryHandler.count>5</maven.wagon.http.retryHandler.count>
<log.capi.invocation>OFF</log.capi.invocation>
@@ -89,6 +89,17 @@
<name>central maven</name>
<url>https://repo1.maven.org/maven2</url>
</repository>
<repository>
<id>apache-snapshots</id>
<url>https://repository.apache.org/content/repositories/snapshots/</url>
<snapshots>
<enabled>true</enabled>
</snapshots>
<releases>
<enabled>false</enabled>
</releases>
</repository>

</repositories>
<modules>
</modules>
8 changes: 4 additions & 4 deletions jvm-packages/xgboost4j-example/pom.xml
@@ -5,11 +5,11 @@
<modelVersion>4.0.0</modelVersion>
<parent>
<groupId>ml.dmlc</groupId>
<artifactId>xgboost-jvm_2.12</artifactId>
<artifactId>xgboost-jvm_2.13</artifactId>
<version>3.1.0-SNAPSHOT</version>
</parent>
<name>xgboost4j-example</name>
<artifactId>xgboost4j-example_2.12</artifactId>
<artifactId>xgboost4j-example_2.13</artifactId>
<version>3.1.0-SNAPSHOT</version>
<packaging>jar</packaging>
<build>
@@ -26,7 +26,7 @@
<dependencies>
<dependency>
<groupId>ml.dmlc</groupId>
<artifactId>xgboost4j-spark_2.12</artifactId>
<artifactId>xgboost4j-spark_2.13</artifactId>
<version>${project.version}</version>
</dependency>
<dependency>
@@ -37,7 +37,7 @@
</dependency>
<dependency>
<groupId>ml.dmlc</groupId>
<artifactId>xgboost4j-flink_2.12</artifactId>
<artifactId>xgboost4j-flink_2.13</artifactId>
<version>${project.version}</version>
</dependency>
<dependency>
6 changes: 3 additions & 3 deletions jvm-packages/xgboost4j-flink/pom.xml
@@ -5,12 +5,12 @@
<modelVersion>4.0.0</modelVersion>
<parent>
<groupId>ml.dmlc</groupId>
<artifactId>xgboost-jvm_2.12</artifactId>
<artifactId>xgboost-jvm_2.13</artifactId>
<version>3.1.0-SNAPSHOT</version>
</parent>

<name>xgboost4j-flink</name>
<artifactId>xgboost4j-flink_2.12</artifactId>
<artifactId>xgboost4j-flink_2.13</artifactId>
<version>3.1.0-SNAPSHOT</version>
<properties>
<flink-ml.version>2.2.0</flink-ml.version>
@@ -30,7 +30,7 @@
<dependencies>
<dependency>
<groupId>ml.dmlc</groupId>
<artifactId>xgboost4j_2.12</artifactId>
<artifactId>xgboost4j_2.13</artifactId>
<version>${project.version}</version>
</dependency>
<dependency>
10 changes: 5 additions & 5 deletions jvm-packages/xgboost4j-spark-gpu/pom.xml
@@ -5,11 +5,11 @@
<modelVersion>4.0.0</modelVersion>
<parent>
<groupId>ml.dmlc</groupId>
<artifactId>xgboost-jvm_2.12</artifactId>
<artifactId>xgboost-jvm_2.13</artifactId>
<version>3.1.0-SNAPSHOT</version>
</parent>
<name>xgboost4j-spark-gpu</name>
<artifactId>xgboost4j-spark-gpu_2.12</artifactId>
<artifactId>xgboost4j-spark-gpu_2.13</artifactId>
<description>JVM Package for XGBoost</description>
<url>https://github.com/dmlc/xgboost/tree/master/jvm-packages</url>
<licenses>
@@ -78,17 +78,17 @@
<dependencies>
<dependency>
<groupId>ml.dmlc</groupId>
<artifactId>xgboost4j_2.12</artifactId>
<artifactId>xgboost4j_2.13</artifactId>
<version>${project.version}</version>
</dependency>
<dependency>
<groupId>ml.dmlc</groupId>
<artifactId>xgboost4j-spark_2.12</artifactId>
<artifactId>xgboost4j-spark_2.13</artifactId>
<version>${project.version}</version>
<exclusions>
<exclusion>
<groupId>ml.dmlc</groupId>
<artifactId>xgboost4j_2.12</artifactId>
<artifactId>xgboost4j_2.13</artifactId>
</exclusion>
</exclusions>
</dependency>
6 changes: 3 additions & 3 deletions jvm-packages/xgboost4j-spark/pom.xml
@@ -5,11 +5,11 @@
<modelVersion>4.0.0</modelVersion>
<parent>
<groupId>ml.dmlc</groupId>
<artifactId>xgboost-jvm_2.12</artifactId>
<artifactId>xgboost-jvm_2.13</artifactId>
<version>3.1.0-SNAPSHOT</version>
</parent>
<name>xgboost4j-spark</name>
<artifactId>xgboost4j-spark_2.12</artifactId>
<artifactId>xgboost4j-spark_2.13</artifactId>
<build>
<plugins>
<plugin>
@@ -46,7 +46,7 @@
<dependencies>
<dependency>
<groupId>ml.dmlc</groupId>
<artifactId>xgboost4j_2.12</artifactId>
<artifactId>xgboost4j_2.13</artifactId>
<version>${project.version}</version>
</dependency>
<dependency>
50 changes: 50 additions & 0 deletions jvm-packages/xgboost4j-spark/python/pyproject.toml
@@ -0,0 +1,50 @@
# Copyright (c) 2025, NVIDIA CORPORATION.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

[project]
name = "xgboost4j"
version = "3.1.0"
authors = [
{ name = "Bobby Wang", email = "[email protected]" },
]
description = "XGBoost4j-Spark pyspark"
readme = "README.md"
requires-python = ">=3.10"
classifiers = [
"Programming Language :: Python :: 3",
"Programming Language :: Python :: 3.10",
"Programming Language :: Python :: 3.11",
"Programming Language :: Python :: 3.12",
"License :: OSI Approved :: Apache Software License",
"Operating System :: OS Independent",
"Environment :: GPU :: NVIDIA CUDA :: 11",
"Environment :: GPU :: NVIDIA CUDA :: 11.4",
"Environment :: GPU :: NVIDIA CUDA :: 11.5",
"Environment :: GPU :: NVIDIA CUDA :: 11.6",
"Environment :: GPU :: NVIDIA CUDA :: 11.7",
"Environment :: GPU :: NVIDIA CUDA :: 11.8",
"Environment :: GPU :: NVIDIA CUDA :: 12",
"Environment :: GPU :: NVIDIA CUDA :: 12 :: 12.0",
"Environment :: GPU :: NVIDIA CUDA :: 12 :: 12.1",
"Environment :: GPU :: NVIDIA CUDA :: 12 :: 12.2",
"Environment :: GPU :: NVIDIA CUDA :: 12 :: 12.3",
"Environment :: GPU :: NVIDIA CUDA :: 12 :: 12.4",
"Environment :: GPU :: NVIDIA CUDA :: 12 :: 12.5",
"Environment :: GPU :: NVIDIA CUDA :: 12 :: 12.6",
Comment on lines +31 to +44

Member: These can be removed; we dropped support for all previous versions in 3.0.

"Environment :: GPU :: NVIDIA CUDA :: 12 :: 12.8",
]

[build-system]
requires = ["setuptools>=61.0"]
build-backend = "setuptools.build_meta"
Empty file.
Empty file.
Empty file.
@@ -0,0 +1,5 @@
import sys

import xgboost4j

sys.modules["ml.dmlc.xgboost4j.scala.spark"] = xgboost4j
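
A note on the module alias above: the apparent intent is to make the JVM-style dotted path "ml.dmlc.xgboost4j.scala.spark" resolve to the local Python wrapper package, so that code which looks the estimator up by its Java package name (for example on the Spark Connect client side) lands on the wrapper. That resolution path is an assumption, not something stated in this diff; a minimal sketch of the aliasing effect itself:

import importlib
import sys

import xgboost4j

# Alias the JVM-style package path to the Python wrapper, as the patch above does.
sys.modules["ml.dmlc.xgboost4j.scala.spark"] = xgboost4j

# Any later lookup of the dotted name now returns the wrapper module directly,
# so XGBoostClassifier is reachable under the Java class's package name.
mod = importlib.import_module("ml.dmlc.xgboost4j.scala.spark")
assert mod.XGBoostClassifier is xgboost4j.XGBoostClassifier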
5 changes: 5 additions & 0 deletions jvm-packages/xgboost4j-spark/python/src/xgboost4j/__init__.py
@@ -0,0 +1,5 @@
from .estimator import XGBoostClassificationModel, XGBoostClassifier

__version__ = "3.1.0"

__all__ = ["XGBoostClassifier", "XGBoostClassificationModel"]
103 changes: 103 additions & 0 deletions jvm-packages/xgboost4j-spark/python/src/xgboost4j/estimator.py
@@ -0,0 +1,103 @@
from typing import Union, List, Any, Optional, Dict

from pyspark import keyword_only
from pyspark.ml.classification import _JavaProbabilisticClassifier, _JavaProbabilisticClassificationModel

from .params import XGBoostParams


class XGBoostClassifier(_JavaProbabilisticClassifier["XGBoostClassificationModel"], XGBoostParams):

Member: Will this be a new class that needs to be used for both normal and Spark Connect invocation? Could we instead modify the _fit method in SparkXGBClassifier to use the try_remote_fit decorator, or would that be a big change?

Contributor Author: Thanks @jjayadeep06 for your reply. Once Spark Connect is involved, things get complicated. We could make the existing xgboost-pyspark support Spark Connect by changing the RDD operations to DataFrame operations, without using any try_remote_xxxx decorators, and yes, we plan to do that.

This PR, however, makes the XGBoost JVM package support Connect by introducing a lightweight Python wrapper. If we layered that wrapper over the XGBoost JVM package on top of the existing xgboost-python-pyspark, it would raise the question of which backend (the XGBoost JVM package or the Python package) should be chosen when running XGBoost over Connect.

Member: Got it.

_input_kwargs: Dict[str, Any]

@keyword_only
def __init__(
self,
*,
featuresCol: Union[str, List[str]] = "features",
labelCol: str = "label",
predictionCol: str = "prediction",
probabilityCol: str = "probability",
rawPredictionCol: str = "rawPrediction",
# SparkParams
numWorkers: Optional[int] = None,
numRound: Optional[int] = None,
forceRepartition: Optional[bool] = None,
numEarlyStoppingRounds: Optional[int] = None,
inferBatchSize: Optional[int] = None,
missing: Optional[float] = None,
useExternalMemory: Optional[bool] = None,
maxNumDevicePages: Optional[int] = None,
maxQuantileBatches: Optional[int] = None,
minCachePageBytes: Optional[int] = None,
feature_names: Optional[List[str]] = None,
feature_types: Optional[List[str]] = None,
# RabitParams
rabitTrackerTimeout: Optional[int] = None,
rabitTrackerHostIp: Optional[str] = None,
rabitTrackerPort: Optional[int] = None,
# GeneralParams
booster: Optional[str] = None,
device: Optional[str] = None,
verbosity: Optional[int] = None,
validate_parameters: Optional[bool] = None,
nthread: Optional[int] = None,
# TreeBoosterParams
eta: Optional[float] = None,
gamma: Optional[float] = None,
max_depth: Optional[int] = None,
min_child_weight: Optional[float] = None,
max_delta_step: Optional[float] = None,
subsample: Optional[float] = None,
sampling_method: Optional[str] = None,
colsample_bytree: Optional[float] = None,
colsample_bylevel: Optional[float] = None,
colsample_bynode: Optional[float] = None,
reg_lambda: Optional[float] = None,
alpha: Optional[float] = None,
tree_method: Optional[str] = None,
scale_pos_weight: Optional[float] = None,
updater: Optional[str] = None,
refresh_leaf: Optional[bool] = None,
process_type: Optional[str] = None,
grow_policy: Optional[str] = None,
max_leaves: Optional[int] = None,
max_bin: Optional[int] = None,
num_parallel_tree: Optional[int] = None,
monotone_constraints: Optional[List[int]] = None,
interaction_constraints: Optional[str] = None,
max_cached_hist_node: Optional[int] = None,
# LearningTaskParams
objective: Optional[str] = None,
num_class: Optional[int] = None,
base_score: Optional[float] = None,
eval_metric: Optional[str] = None,
seed: Optional[int] = None,
seed_per_iteration: Optional[bool] = None,
tweedie_variance_power: Optional[float] = None,
huber_slope: Optional[float] = None,
aft_loss_distribution: Optional[str] = None,
lambdarank_pair_method: Optional[str] = None,
lambdarank_num_pair_per_sample: Optional[int] = None,
lambdarank_unbiased: Optional[bool] = None,
lambdarank_bias_norm: Optional[float] = None,
ndcg_exp_gain: Optional[bool] = None,
# DartBoosterParams
sample_type: Optional[str] = None,
normalize_type: Optional[str] = None,
rate_drop: Optional[float] = None,
one_drop: Optional[bool] = None,
skip_drop: Optional[float] = None,
**kwargs: Any,
):
super().__init__()
self._java_obj = self._new_java_obj(
"ml.dmlc.xgboost4j.scala.spark.XGBoostClassifier", self.uid
)
self._set_params(**self._input_kwargs)

def _create_model(self, java_model: "JavaObject") -> "XGBoostClassificationModel":
return XGBoostClassificationModel(java_model)


class XGBoostClassificationModel(_JavaProbabilisticClassificationModel, XGBoostParams):
pass
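
For orientation, here is a minimal usage sketch of the wrapper defined above. The session setup, DataFrame schema, and the presence of the xgboost4j/xgboost4j-spark jars on the cluster are assumptions for illustration; only XGBoostClassifier, XGBoostClassificationModel, and the keyword parameters come from the code in this file:

from pyspark.ml.linalg import Vectors
from pyspark.sql import SparkSession

from xgboost4j import XGBoostClassifier

# Assumes the xgboost4j and xgboost4j-spark jars are already on the classpath;
# with Spark Connect, the builder would point at a remote endpoint instead.
spark = SparkSession.builder.appName("xgboost4j-wrapper-demo").getOrCreate()

train = spark.createDataFrame(
    [
        (Vectors.dense(1.0, 2.0), 0.0),
        (Vectors.dense(3.0, 4.0), 1.0),
    ],
    ["features", "label"],
)

# Parameters mirror the keyword-only constructor defined above.
classifier = XGBoostClassifier(
    numRound=10,
    max_depth=3,
    eta=0.3,
    objective="binary:logistic",
)
model = classifier.fit(train)
model.transform(train).select("prediction", "probability").show()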