Skip to content

Commit

Permalink
Compatibility for xgboost>=1.7.0, fix master CI (#242)
Browse files Browse the repository at this point in the history
* Update requirements-test.txt

Signed-off-by: Antoni Baum <[email protected]>

* Update requirements-test.txt

Signed-off-by: Antoni Baum <[email protected]>

* Update requirements-test.txt

Signed-off-by: Antoni Baum <[email protected]>

* Update requirements-test.txt

Signed-off-by: Antoni Baum <[email protected]>

* Compatibility for xgboost 1.7.0

Signed-off-by: Antoni Baum <[email protected]>

* Fix MRO

Signed-off-by: Antoni Baum <[email protected]>

Signed-off-by: Antoni Baum <[email protected]>
  • Loading branch information
Yard1 authored Oct 31, 2022
1 parent d0647bc commit dd51311
Show file tree
Hide file tree
Showing 3 changed files with 45 additions and 15 deletions.
3 changes: 2 additions & 1 deletion requirements-test.txt
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,11 @@ packaging
petastorm
pytest
pyarrow
ray[tune]
ray[tune, data]
scikit-learn
modin
dask

#workaround for now
protobuf<4.0.0
tensorboardX==2.2
51 changes: 38 additions & 13 deletions xgboost_ray/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,17 @@
class EarlyStopException(XGBoostError):
pass


# From xgboost>=1.7.0, rabit is replaced by a collective communicator
try:
from xgboost.collective import CommunicatorContext
rabit = None
HAS_COLLECTIVE = True
except ImportError:
from xgboost import rabit # noqa
CommunicatorContext = None
HAS_COLLECTIVE = False

from xgboost_ray.callback import DistributedCallback, \
DistributedCallbackContainer
from xgboost_ray.compat import TrainingCallback, RabitTracker, LEGACY_CALLBACK
Expand Down Expand Up @@ -66,7 +77,7 @@ def inner_f(*args, **kwargs):
RayDeviceQuantileDMatrix, RayDataIter, concat_dataframes, \
LEGACY_MATRIX
from xgboost_ray.session import init_session, put_queue, \
set_session_queue
set_session_queue, get_rabit_rank


def _get_environ(item: str, old_val: Any):
Expand Down Expand Up @@ -237,25 +248,40 @@ def _stop_rabit_tracker(rabit_process: multiprocessing.Process):
rabit_process.terminate()


class _RabitContext:
class _RabitContextBase:
"""This context is used by local training actors to connect to the
Rabit tracker.
Args:
actor_id (str): Unique actor ID
args (list): Arguments for Rabit initialisation. These are
args (dict): Arguments for Rabit initialisation. These are
environment variables to configure Rabit clients.
"""

def __init__(self, actor_id, args):
def __init__(self, actor_id: int, args: dict):
args["DMLC_TASK_ID"] = "[xgboost.ray]:" + actor_id
self.args = args
self.args.append(("DMLC_TASK_ID=[xgboost.ray]:" + actor_id).encode())

def __enter__(self):
xgb.rabit.init(self.args)

def __exit__(self, *args):
xgb.rabit.finalize()
# From xgboost>=1.7.0, rabit is replaced by a collective communicator
if HAS_COLLECTIVE:

class _RabitContext(_RabitContextBase, CommunicatorContext):
pass

else:

class _RabitContext(_RabitContextBase):
def __init__(self, actor_id: int, args: dict):
super().__init__(actor_id, args)
self._list_args = [("%s=%s" % item).encode()
for item in self.args.items()]

def __enter__(self):
xgb.rabit.init(self._list_args)

def __exit__(self, *args):
xgb.rabit.finalize()


def _ray_get_actor_cpus():
Expand Down Expand Up @@ -517,12 +543,12 @@ def _save_checkpoint_callback(self):

class _SaveInternalCheckpointCallback(TrainingCallback):
def after_iteration(self, model, epoch, evals_log):
if xgb.rabit.get_rank() == 0 and \
if get_rabit_rank() == 0 and \
epoch % this.checkpoint_frequency == 0:
put_queue(_Checkpoint(epoch, pickle.dumps(model)))

def after_training(self, model):
if xgb.rabit.get_rank() == 0:
if get_rabit_rank() == 0:
put_queue(_Checkpoint(-1, pickle.dumps(model)))
return model

Expand Down Expand Up @@ -1054,8 +1080,7 @@ def handle_actor_failure(actor_id):
maybe_log("[RayXGBoost] Starting XGBoost training.")

# Start Rabit tracker for gradient sharing
rabit_process, env = _start_rabit_tracker(alive_actors)
rabit_args = [("%s=%s" % item).encode() for item in env.items()]
rabit_process, rabit_args = _start_rabit_tracker(alive_actors)

# Load checkpoint if we have one. In that case we need to adjust the
# number of training rounds.
Expand Down
6 changes: 5 additions & 1 deletion xgboost_ray/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,11 @@ def get_actor_rank() -> int:
@PublicAPI
def get_rabit_rank() -> int:
import xgboost as xgb
return xgb.rabit.get_rank()
try:
# From xgboost>=1.7.0, rabit is replaced by a collective communicator
return xgb.collective.get_rank()
except (ImportError, AttributeError):
return xgb.rabit.get_rank()


@PublicAPI
Expand Down

0 comments on commit dd51311

Please sign in to comment.