Skip to content

Commit

Permalink
Mixin.
Browse files Browse the repository at this point in the history
  • Loading branch information
trivialfis committed Dec 15, 2024
1 parent b8ff1e0 commit dfa0c41
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 4 deletions.
16 changes: 13 additions & 3 deletions python-package/xgboost/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -1451,7 +1451,18 @@ def _ref_data_from_csr(self, csr: scipy.sparse.csr_matrix) -> None:
)


class QuantileDMatrix(DMatrix):
class _RefMixIn:
@property
def ref(self) -> Optional[weakref.ReferenceType]:
"""Internal method for retrieving a reference to the training DMatrix."""
return self._ref

@ref.setter
def ref(self, ref: weakref.ReferenceType) -> None:
self._ref = ref


class QuantileDMatrix(DMatrix, _RefMixIn):
"""A DMatrix variant that generates quantilized data directly from input for the
``hist`` tree method. This DMatrix is primarily designed to save memory in training
by avoiding intermediate storage. Set ``max_bin`` to control the number of bins
Expand Down Expand Up @@ -1644,7 +1655,7 @@ def _init(
self.ref = weakref.ref(ref)


class ExtMemQuantileDMatrix(DMatrix):
class ExtMemQuantileDMatrix(DMatrix, _RefMixIn):
"""The external memory version of the :py:class:`QuantileDMatrix`.
See :doc:`/tutorials/external_memory` for explanation and usage examples, and
Expand Down Expand Up @@ -1742,7 +1753,6 @@ def _init(
_check_call(ret)
self.handle = handle

self._ref: Optional[weakref.ReferenceType] = None
if ref is not None:
self.ref = weakref.ref(ref)

Expand Down
7 changes: 6 additions & 1 deletion python-package/xgboost/training.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
"""Training Library containing training routines."""
import copy
import os
import weakref
from typing import Any, Dict, Iterable, List, Optional, Sequence, Tuple, Union, cast

import numpy as np
Expand Down Expand Up @@ -146,9 +147,13 @@ def train(

callbacks = [] if callbacks is None else copy.copy(list(callbacks))
evals = list(evals) if evals else []

for va, _ in evals:
if not isinstance(va, DMatrix):
raise TypeError("Invalid type for the `evals`,")

if hasattr(va, "ref") and va.ref is not None:
if va is not dtrain and va.ref is not dtrain:
if va is not dtrain and va.ref is not weakref.ref(dtrain):
raise ValueError(
"Training dataset should be used as a reference when constructing "
"the `QuantileDMatrix` for evaluation."
Expand Down

0 comments on commit dfa0c41

Please sign in to comment.