diff --git a/R-package/R/xgb.plot.multi.trees.R b/R-package/R/xgb.plot.multi.trees.R index 67a7a045172e..8b4f0eeed037 100644 --- a/R-package/R/xgb.plot.multi.trees.R +++ b/R-package/R/xgb.plot.multi.trees.R @@ -23,6 +23,7 @@ #' @inheritParams xgb.plot.tree #' @param features_keep Number of features to keep in each position of the multi trees, #' by default 5. +#' @param render Should the graph be rendered or not? The default is `TRUE`. #' @inherit xgb.plot.tree return #' #' @examples diff --git a/R-package/R/xgb.plot.tree.R b/R-package/R/xgb.plot.tree.R index 881ea4fb3495..f5d53bb3432e 100644 --- a/R-package/R/xgb.plot.tree.R +++ b/R-package/R/xgb.plot.tree.R @@ -3,63 +3,38 @@ #' Read a tree model text dump and plot the model. #' #' @details -#' When using `style="xgboost"`, the content of each node is visualized as follows: -#' - For non-terminal nodes, it will display the split condition (number or name if -#' available, and the condition that would decide to which node to go next). -#' - Those nodes will be connected to their children by arrows that indicate whether the -#' branch corresponds to the condition being met or not being met. +#' The content of each node is visualized as follows: +#' - For non-terminal nodes, it will display the split condition (number or name +#' if available, and the condition that would decide to which node to go +#' next). +#' - Those nodes will be connected to their children by arrows that indicate +#' whether the branch corresponds to the condition being met or not being met. #' - Terminal (leaf) nodes contain the margin to add when ending there. #' -#' When using `style="R"`, the content of each node is visualized like this: -#' - *Feature name*. -#' - *Cover:* The sum of second order gradients of training data. -#' For the squared loss, this simply corresponds to the number of instances in the node. -#' The deeper in the tree, the lower the value. 
-#' - *Gain* (for split nodes): Information gain metric of a split -#' (corresponds to the importance of the node in the model). -#' - *Value* (for leaves): Margin value that the leaf may contribute to the prediction. -#' -#' The tree root nodes also indicate the tree index (0-based). -#' #' The "Yes" branches are marked by the "< split_value" label. #' The branches also used for missing values are marked as bold #' (as in "carrying extra capacity"). #' -#' This function uses [GraphViz](https://www.graphviz.org/) as DiagrammeR backend. +#' This function uses [GraphViz](https://www.graphviz.org/) as DiagrammeR +#' backend. #' -#' @param model Object of class `xgb.Booster`. If it contains feature names (they can be set through -#' [setinfo()], they will be used in the output from this function. -#' @param trees An integer vector of tree indices that should be used. -#' The default (`NULL`) uses all trees. -#' Useful, e.g., in multiclass classification to get only -#' the trees of one class. *Important*: the tree index in XGBoost models -#' is zero-based (e.g., use `trees = 0:2` for the first three trees). +#' @param model Object of class `xgb.Booster`. If it contains feature names +#' (they can be set through [setinfo()]), they will be used in the +#' output from this function. +#' @param tree_idx An integer of the tree index that should be used. This +#' is a 1-based index. #' @param plot_width,plot_height Width and height of the graph in pixels. #' The values are passed to `DiagrammeR::render_graph()`. -#' @param render Should the graph be rendered or not? The default is `TRUE`. -#' @param show_node_id a logical flag for whether to show node id's in the graph. -#' @param style Style to use for the plot: -#' - `"xgboost"`: will use the plot style defined in the core XGBoost library, -#' which is shared between different interfaces through the 'dot' format. This -#' style was not available before version 2.1.0 in R.
It always plots the trees -#' vertically (from top to bottom). -#' - `"R"`: will use the style defined from XGBoost's R interface, which predates -#' the introducition of the standardized style from the core library. It might plot -#' the trees horizontally (from left to right). -#' -#' Note that `style="xgboost"` is only supported when all of the following conditions are met: -#' - Only a single tree is being plotted. -#' - Node IDs are not added to the graph. -#' - The graph is being returned as `htmlwidget` (`render=TRUE`). +#' @param with_stats Whether to dump some additional statistics about the +#' splits. When this option is on, the model dump contains two additional +#' values: gain is the approximate loss function gain we get in each split; +#' cover is the sum of second order gradient in each node. #' @param ... Currently not used. #' @return -#' The value depends on the `render` parameter: -#' - If `render = TRUE` (default): Rendered graph object which is an htmlwidget of -#' class `grViz`. Similar to "ggplot" objects, it needs to be printed when not -#' running from the command line. -#' - If `render = FALSE`: Graph object which is of DiagrammeR's class `dgr_graph`. -#' This could be useful if one wants to modify some of the graph attributes -#' before rendering the graph with `DiagrammeR::render_graph()`. +#' +#' Rendered graph object which is an htmlwidget of class `grViz`. Similar to +#' "ggplot" objects, it needs to be printed when not running from the command +#' line.
#' #' @examples #' data(agaricus.train, package = "xgboost") @@ -73,119 +48,35 @@ #' objective = "binary:logistic" #' ) #' -#' # plot the first tree, using the style from xgboost's core library -#' # (this plot should look identical to the ones generated from other -#' # interfaces like the python package for xgboost) -#' xgb.plot.tree(model = bst, trees = 1, style = "xgboost") -#' -#' # plot all the trees -#' xgb.plot.tree(model = bst, trees = NULL) +#' # plot the first tree +#' xgb.plot.tree(model = bst, tree_idx = 1) #' -#' # plot only the first tree and display the node ID: -#' xgb.plot.tree(model = bst, trees = 0, show_node_id = TRUE) #' #' \dontrun{ #' # Below is an example of how to save this plot to a file. -#' # Note that for export_graph() to work, the {DiagrammeRsvg} -#' # and {rsvg} packages must also be installed. #' #' library(DiagrammeR) #' -#' gr <- xgb.plot.tree(model = bst, trees = 0:1, render = FALSE) -#' export_graph(gr, "tree.pdf", width = 1500, height = 1900) -#' export_graph(gr, "tree.png", width = 1500, height = 1900) +#' gr <- xgb.plot.tree(model = bst, tree_idx = 1) +#' htmlwidgets::saveWidget(gr, 'plot.html') #' } #' #' @export -xgb.plot.tree <- function(model = NULL, trees = NULL, plot_width = NULL, plot_height = NULL, - render = TRUE, show_node_id = FALSE, style = c("R", "xgboost"), ...) { +xgb.plot.tree <- function(model, + tree_idx = 1, + plot_width = NULL, + plot_height = NULL, + with_stats = FALSE, ...) { check.deprecation(...) if (!inherits(model, "xgb.Booster")) { - stop("model: Has to be an object of class xgb.Booster") + stop("model has to be an object of the class xgb.Booster") } - if (!requireNamespace("DiagrammeR", quietly = TRUE)) { - stop("DiagrammeR package is required for xgb.plot.tree", call. 
= FALSE) - } - - style <- as.character(head(style, 1L)) - stopifnot(style %in% c("R", "xgboost")) - if (style == "xgboost") { - if (NROW(trees) != 1L || !render || show_node_id) { - stop("style='xgboost' is only supported for single, rendered tree, without node IDs.") - } - - txt <- xgb.dump(model, dump_format = "dot") - return(DiagrammeR::grViz(txt[[trees + 1]], width = plot_width, height = plot_height)) - } - - dt <- xgb.model.dt.tree(model = model, trees = trees) - - dt[, label := paste0(Feature, "\nCover: ", Cover, ifelse(Feature == "Leaf", "\nValue: ", "\nGain: "), Gain)] - if (show_node_id) - dt[, label := paste0(ID, ": ", label)] - dt[Node == 0, label := paste0("Tree ", Tree, "\n", label)] - dt[, shape := "rectangle"][Feature == "Leaf", shape := "oval"] - dt[, filledcolor := "Beige"][Feature == "Leaf", filledcolor := "Khaki"] - # in order to draw the first tree on top: - dt <- dt[order(-Tree)] - - nodes <- DiagrammeR::create_node_df( - n = nrow(dt), - ID = dt$ID, - label = dt$label, - fillcolor = dt$filledcolor, - shape = dt$shape, - data = dt$Feature, - fontcolor = "black") - - if (nrow(dt[Feature != "Leaf"]) != 0) { - edges <- DiagrammeR::create_edge_df( - from = match(rep(dt[Feature != "Leaf", c(ID)], 2), dt$ID), - to = match(dt[Feature != "Leaf", c(Yes, No)], dt$ID), - label = c( - dt[Feature != "Leaf", paste("<", Split)], - rep("", nrow(dt[Feature != "Leaf"])) - ), - style = c( - dt[Feature != "Leaf", ifelse(Missing == Yes, "bold", "solid")], - dt[Feature != "Leaf", ifelse(Missing == No, "bold", "solid")] - ), - rel = "leading_to") - } else { - edges <- NULL + stop("The DiagrammeR package is required for xgb.plot.tree", call. 
= FALSE) } - graph <- DiagrammeR::create_graph( - nodes_df = nodes, - edges_df = edges, - attr_theme = NULL - ) - graph <- DiagrammeR::add_global_graph_attrs( - graph = graph, - attr_type = "graph", - attr = c("layout", "rankdir"), - value = c("dot", "LR") + txt <- xgb.dump(model, dump_format = "dot", with_stats = with_stats) + DiagrammeR::grViz( + txt[[tree_idx]], width = plot_width, height = plot_height ) - graph <- DiagrammeR::add_global_graph_attrs( - graph = graph, - attr_type = "node", - attr = c("color", "style", "fontname"), - value = c("DimGray", "filled", "Helvetica") - ) - graph <- DiagrammeR::add_global_graph_attrs( - graph = graph, - attr_type = "edge", - attr = c("color", "arrowsize", "arrowhead", "fontname"), - value = c("DimGray", "1.5", "vee", "Helvetica") - ) - - if (!render) return(invisible(graph)) - - DiagrammeR::render_graph(graph, width = plot_width, height = plot_height) } - -# Avoid error messages during CRAN check. -# The reason is that these variables are never declared -# They are mainly column names inferred by Data.table... -globalVariables(c("Feature", "ID", "Cover", "Gain", "Split", "Yes", "No", "Missing", ".", "shape", "filledcolor", "label")) diff --git a/R-package/R/xgb.train.R b/R-package/R/xgb.train.R index 4fb5c429c1a3..cafdde2da856 100644 --- a/R-package/R/xgb.train.R +++ b/R-package/R/xgb.train.R @@ -103,7 +103,7 @@ #' objective is non-convex. #' #' See the tutorials [Custom Objective and Evaluation Metric](https://xgboost.readthedocs.io/en/stable/tutorials/custom_metric_obj.html) -#' and [Advanced Usage of Custom Objectives](https://xgboost.readthedocs.io/en/stable/tutorials/advanced_custom_obj) +#' and [Advanced Usage of Custom Objectives](https://xgboost.readthedocs.io/en/latest/tutorials/advanced_custom_obj.html) #' for more information about custom objectives. #' #' - `base_score`: The initial prediction score of all instances, global bias. Default: 0.5. 
diff --git a/R-package/man/xgb.plot.multi.trees.Rd b/R-package/man/xgb.plot.multi.trees.Rd index 6b165f9af113..df72ee452ee6 100644 --- a/R-package/man/xgb.plot.multi.trees.Rd +++ b/R-package/man/xgb.plot.multi.trees.Rd @@ -14,8 +14,9 @@ xgb.plot.multi.trees( ) } \arguments{ -\item{model}{Object of class \code{xgb.Booster}. If it contains feature names (they can be set through -\code{\link[=setinfo]{setinfo()}}, they will be used in the output from this function.} +\item{model}{Object of class \code{xgb.Booster}. If it contains feature names +(they can be set through \code{\link[=setinfo]{setinfo()}}), they will be used in the +output from this function.} \item{features_keep}{Number of features to keep in each position of the multi trees, by default 5.} @@ -28,15 +29,9 @@ The values are passed to \code{DiagrammeR::render_graph()}.} \item{...}{Currently not used.} } \value{ -The value depends on the \code{render} parameter: -\itemize{ -\item If \code{render = TRUE} (default): Rendered graph object which is an htmlwidget of -class \code{grViz}. Similar to "ggplot" objects, it needs to be printed when not -running from the command line. -\item If \code{render = FALSE}: Graph object which is of DiagrammeR's class \code{dgr_graph}. -This could be useful if one wants to modify some of the graph attributes -before rendering the graph with \code{DiagrammeR::render_graph()}. -} +Rendered graph object which is an htmlwidget of class \code{grViz}. Similar to +"ggplot" objects, it needs to be printed when not running from the command +line. } \description{ Visualization of the ensemble of trees as a single collective unit.
diff --git a/R-package/man/xgb.plot.tree.Rd b/R-package/man/xgb.plot.tree.Rd index cf0c2d0dd9eb..c58187d0f520 100644 --- a/R-package/man/xgb.plot.tree.Rd +++ b/R-package/man/xgb.plot.tree.Rd @@ -5,95 +5,57 @@ \title{Plot boosted trees} \usage{ xgb.plot.tree( - model = NULL, - trees = NULL, + model, + tree_idx = 1, plot_width = NULL, plot_height = NULL, - render = TRUE, - show_node_id = FALSE, - style = c("R", "xgboost"), + with_stats = FALSE, ... ) } \arguments{ -\item{model}{Object of class \code{xgb.Booster}. If it contains feature names (they can be set through -\code{\link[=setinfo]{setinfo()}}, they will be used in the output from this function.} +\item{model}{Object of class \code{xgb.Booster}. If it contains feature names +(they can be set through \code{\link[=setinfo]{setinfo()}}), they will be used in the +output from this function.} -\item{trees}{An integer vector of tree indices that should be used. -The default (\code{NULL}) uses all trees. -Useful, e.g., in multiclass classification to get only -the trees of one class. \emph{Important}: the tree index in XGBoost models -is zero-based (e.g., use \code{trees = 0:2} for the first three trees).} +\item{tree_idx}{An integer of the tree index that should be used. This +is a 1-based index.} \item{plot_width, plot_height}{Width and height of the graph in pixels. The values are passed to \code{DiagrammeR::render_graph()}.} -\item{render}{Should the graph be rendered or not? The default is \code{TRUE}.} - -\item{show_node_id}{a logical flag for whether to show node id's in the graph.} - -\item{style}{Style to use for the plot: -\itemize{ -\item \code{"xgboost"}: will use the plot style defined in the core XGBoost library, -which is shared between different interfaces through the 'dot' format. This -style was not available before version 2.1.0 in R. It always plots the trees -vertically (from top to bottom).
-\item \code{"R"}: will use the style defined from XGBoost's R interface, which predates -the introducition of the standardized style from the core library. It might plot -the trees horizontally (from left to right). -} - -Note that \code{style="xgboost"} is only supported when all of the following conditions are met: -\itemize{ -\item Only a single tree is being plotted. -\item Node IDs are not added to the graph. -\item The graph is being returned as \code{htmlwidget} (\code{render=TRUE}). -}} +\item{with_stats}{Whether to dump some additional statistics about the +splits. When this option is on, the model dump contains two additional +values: gain is the approximate loss function gain we get in each split; +cover is the sum of second order gradient in each node.} \item{...}{Currently not used.} } \value{ -The value depends on the \code{render} parameter: -\itemize{ -\item If \code{render = TRUE} (default): Rendered graph object which is an htmlwidget of -class \code{grViz}. Similar to "ggplot" objects, it needs to be printed when not -running from the command line. -\item If \code{render = FALSE}: Graph object which is of DiagrammeR's class \code{dgr_graph}. -This could be useful if one wants to modify some of the graph attributes -before rendering the graph with \code{DiagrammeR::render_graph()}. -} +Rendered graph object which is an htmlwidget of class \code{grViz}. Similar to +"ggplot" objects, it needs to be printed when not running from the command +line. } \description{ Read a tree model text dump and plot the model. } \details{ -When using \code{style="xgboost"}, the content of each node is visualized as follows: +The content of each node is visualized as follows: \itemize{ -\item For non-terminal nodes, it will display the split condition (number or name if -available, and the condition that would decide to which node to go next).
-\item Those nodes will be connected to their children by arrows that indicate whether the -branch corresponds to the condition being met or not being met. +\item For non-terminal nodes, it will display the split condition (number or name +if available, and the condition that would decide to which node to go +next). +\item Those nodes will be connected to their children by arrows that indicate +whether the branch corresponds to the condition being met or not being met. \item Terminal (leaf) nodes contain the margin to add when ending there. } -When using \code{style="R"}, the content of each node is visualized like this: -\itemize{ -\item \emph{Feature name}. -\item \emph{Cover:} The sum of second order gradients of training data. -For the squared loss, this simply corresponds to the number of instances in the node. -The deeper in the tree, the lower the value. -\item \emph{Gain} (for split nodes): Information gain metric of a split -(corresponds to the importance of the node in the model). -\item \emph{Value} (for leaves): Margin value that the leaf may contribute to the prediction. -} - -The tree root nodes also indicate the tree index (0-based). - The "Yes" branches are marked by the "< split_value" label. The branches also used for missing values are marked as bold (as in "carrying extra capacity"). -This function uses \href{https://www.graphviz.org/}{GraphViz} as DiagrammeR backend. +This function uses \href{https://www.graphviz.org/}{GraphViz} as DiagrammeR +backend. 
} \examples{ data(agaricus.train, package = "xgboost") @@ -107,27 +69,17 @@ bst <- xgb.train( objective = "binary:logistic" ) -# plot the first tree, using the style from xgboost's core library -# (this plot should look identical to the ones generated from other -# interfaces like the python package for xgboost) -xgb.plot.tree(model = bst, trees = 1, style = "xgboost") - -# plot all the trees -xgb.plot.tree(model = bst, trees = NULL) +# plot the first tree +xgb.plot.tree(model = bst, tree_idx = 1) -# plot only the first tree and display the node ID: -xgb.plot.tree(model = bst, trees = 0, show_node_id = TRUE) \dontrun{ # Below is an example of how to save this plot to a file. -# Note that for export_graph() to work, the {DiagrammeRsvg} -# and {rsvg} packages must also be installed. library(DiagrammeR) -gr <- xgb.plot.tree(model = bst, trees = 0:1, render = FALSE) -export_graph(gr, "tree.pdf", width = 1500, height = 1900) -export_graph(gr, "tree.png", width = 1500, height = 1900) +gr <- xgb.plot.tree(model = bst, tree_idx = 1) +htmlwidgets::saveWidget(gr, 'plot.html') } } diff --git a/R-package/man/xgb.train.Rd b/R-package/man/xgb.train.Rd index cbf7abf0f48b..be4290d9806d 100644 --- a/R-package/man/xgb.train.Rd +++ b/R-package/man/xgb.train.Rd @@ -129,7 +129,7 @@ the Hessian will be clipped, so one might consider using the expected Hessian (F objective is non-convex. See the tutorials \href{https://xgboost.readthedocs.io/en/stable/tutorials/custom_metric_obj.html}{Custom Objective and Evaluation Metric} -and \href{https://xgboost.readthedocs.io/en/stable/tutorials/advanced_custom_obj}{Advanced Usage of Custom Objectives} +and \href{https://xgboost.readthedocs.io/en/latest/tutorials/advanced_custom_obj.html}{Advanced Usage of Custom Objectives} for more information about custom objectives. \item \code{base_score}: The initial prediction score of all instances, global bias. Default: 0.5. \item \code{eval_metric}: Evaluation metrics for validation data. 
diff --git a/R-package/tests/testthat/test_helpers.R b/R-package/tests/testthat/test_helpers.R index 8ddba9519fc6..bfffe9e7878c 100644 --- a/R-package/tests/testthat/test_helpers.R +++ b/R-package/tests/testthat/test_helpers.R @@ -408,6 +408,19 @@ test_that("xgb.plot.tree works with and without feature names", { .skip_if_vcd_not_available() expect_silent(xgb.plot.tree(feature_names = feature.names, model = bst.Tree.unnamed)) expect_silent(xgb.plot.tree(model = bst.Tree)) + + ## Categorical + y <- rnorm(100) + x <- sample(3, size = 100 * 3, replace = TRUE) |> matrix(nrow = 100) + x <- x - 1 + dm <- xgb.DMatrix(data = x, label = y) + setinfo(dm, "feature_type", c("c", "c", "c")) + model <- xgb.train( + data = dm, + params = list(tree_method = "hist"), + nrounds = 2 + ) + expect_silent(xgb.plot.tree(model = model)) }) test_that("xgb.plot.multi.trees works with and without feature names", { diff --git a/python-package/xgboost/core.py b/python-package/xgboost/core.py index 05c0cc30fa82..c2034652322d 100644 --- a/python-package/xgboost/core.py +++ b/python-package/xgboost/core.py @@ -2994,7 +2994,7 @@ def get_dump( fmap : Name of the file containing feature map names. with_stats : - Controls whether the split statistics are output. + Controls whether the split statistics should be included. dump_format : Format of model dump. Can be 'text', 'json' or 'dot'. 
diff --git a/python-package/xgboost/plotting.py b/python-package/xgboost/plotting.py index ebb06b2afb11..552261c2d1a0 100644 --- a/python-package/xgboost/plotting.py +++ b/python-package/xgboost/plotting.py @@ -2,6 +2,7 @@ # pylint: disable=too-many-branches """Plotting Library.""" import json +import warnings from io import BytesIO from typing import Any, Optional, Union @@ -153,12 +154,14 @@ def to_graphviz( booster: Union[Booster, XGBModel], *, fmap: PathLike = "", - num_trees: int = 0, + num_trees: Optional[int] = None, rankdir: Optional[str] = None, yes_color: Optional[str] = None, no_color: Optional[str] = None, condition_node_params: Optional[dict] = None, leaf_node_params: Optional[dict] = None, + with_stats: bool = False, + tree_idx: int = 0, **kwargs: Any, ) -> GraphvizSource: """Convert specified tree to graphviz instance. IPython can automatically plot @@ -172,7 +175,11 @@ def to_graphviz( fmap : The name of feature map file num_trees : + + .. deprecated:: 3.0 + Specify the ordinal number of target tree + rankdir : Passed to graphviz via graph_attr yes_color : @@ -197,6 +204,18 @@ def to_graphviz( 'style': 'filled', 'fillcolor': '#e48038'} + with_stats : + + .. versionadded:: 3.0 + + Controls whether the split statistics should be included. + + tree_idx : + + .. versionadded:: 3.0 + + Specify the ordinal index of target tree. + kwargs : Other keywords passed to graphviz graph_attr, e.g. ``graph [ {key} = {value} ]`` @@ -243,35 +262,68 @@ def to_graphviz( if kwargs: parameters += ":" parameters += json.dumps(kwargs) - tree = booster.get_dump(fmap=fmap, dump_format=parameters)[num_trees] + + if num_trees is not None: + warnings.warn( + "The `num_trees` parameter is deprecated, use `tree_idx` instead. ", + FutureWarning, + ) + if tree_idx not in (0, num_trees): + raise ValueError( + "Both `num_trees` and `tree_idx` are used, prefer `tree_idx` instead."
+ ) + tree_idx = num_trees + + tree = booster.get_dump(fmap=fmap, dump_format=parameters, with_stats=with_stats)[ + tree_idx + ] g = Source(tree) return g +@_deprecate_positional_args def plot_tree( - booster: Booster, + booster: Union[Booster, XGBModel], + *, fmap: PathLike = "", - num_trees: int = 0, + num_trees: Optional[int] = None, rankdir: Optional[str] = None, ax: Optional[Axes] = None, + with_stats: bool = False, + tree_idx: int = 0, **kwargs: Any, ) -> Axes: """Plot specified tree. Parameters ---------- - booster : Booster, XGBModel + booster : Booster or XGBModel instance fmap: str (optional) The name of feature map file - num_trees : int, default 0 - Specify the ordinal number of target tree + num_trees : + + .. deprecated:: 3.0 + rankdir : str, default "TB" Passed to graphviz via graph_attr ax : matplotlib Axes, default None Target axes instance. If None, new figure and axes will be created. + + with_stats : + + .. versionadded:: 3.0 + + See :py:func:`to_graphviz`. + + tree_idx : + + .. versionadded:: 3.0 + + See :py:func:`to_graphviz`. + kwargs : - Other keywords passed to to_graphviz + Other keywords passed to :py:func:`to_graphviz` Returns ------- @@ -287,7 +339,15 @@ def plot_tree( if ax is None: _, ax = plt.subplots(1, 1) - g = to_graphviz(booster, fmap=fmap, num_trees=num_trees, rankdir=rankdir, **kwargs) + g = to_graphviz( + booster, + fmap=fmap, + num_trees=num_trees, + rankdir=rankdir, + with_stats=with_stats, + tree_idx=tree_idx, + **kwargs, + ) s = BytesIO() s.write(g.pipe(format="png")) diff --git a/src/tree/tree_model.cc b/src/tree/tree_model.cc index 2a5a40b970a1..040022c373a4 100644 --- a/src/tree/tree_model.cc +++ b/src/tree/tree_model.cc @@ -82,11 +82,11 @@ class TreeGenerator { } return res; } - /* \brief Find the first occurrence of key in input and replace it with corresponding + /* @brief Find the first occurrence of key in input and replace it with corresponding * value. 
*/ - static std::string Match(std::string const& input, - std::map const& replacements) { + [[nodiscard]] static std::string Match(std::string const& input, + std::map const& replacements) { std::string result = input; for (auto const& kv : replacements) { auto pos = result.find(kv.first); @@ -671,16 +671,31 @@ class GraphvizGenerator : public TreeGenerator { std::string PlainNode(RegTree const& tree, bst_node_t nidx, uint32_t) const override { auto split_index = tree.SplitIndex(nidx); auto cond = tree.SplitCond(nidx); - static std::string const kNodeTemplate = " {nid} [ label=\"{fname}{<}{cond}\" {params}]\n"; + static std::string const kNodeTemplate = + " {nid} [ label=\"{fname}{<}{cond}{stat}\" {params}]\n"; bool has_less = (split_index >= fmap_.Size()) || fmap_.TypeOf(split_index) != FeatureMap::kIndicator; - std::string result = - SuperT::Match(kNodeTemplate, {{"{nid}", std::to_string(nidx)}, - {"{fname}", GetFeatureName(fmap_, split_index)}, - {"{<}", has_less ? "<" : ""}, - {"{cond}", has_less ? ToStr(cond) : ""}, - {"{params}", param_.condition_node_params}}); + std::string result; + if (this->with_stats_) { + CHECK(!tree.IsMultiTarget()) << MTNotImplemented(); + result = SuperT::Match( + kNodeTemplate, {{"{nid}", std::to_string(nidx)}, + {"{fname}", GetFeatureName(fmap_, split_index)}, + {"{<}", has_less ? "<" : ""}, + {"{cond}", has_less ? ToStr(cond) : ""}, + {"{stat}", Match("\ncover={cover}\ngain={gain}", + {{"{cover}", std::to_string(tree.Stat(nidx).sum_hess)}, + {"{gain}", std::to_string(tree.Stat(nidx).loss_chg)}})}, + {"{params}", param_.condition_node_params}}); + } else { + result = SuperT::Match(kNodeTemplate, {{"{nid}", std::to_string(nidx)}, + {"{fname}", GetFeatureName(fmap_, split_index)}, + {"{<}", has_less ? "<" : ""}, + {"{cond}", has_less ? 
ToStr(cond) : ""}, + {"{stat}", ""}, + {"{params}", param_.condition_node_params}}); + } result += BuildEdge(tree, nidx, tree.LeftChild(nidx), true); result += BuildEdge(tree, nidx, tree.RightChild(nidx), false); @@ -708,21 +723,31 @@ class GraphvizGenerator : public TreeGenerator { } std::string LeafNode(RegTree const& tree, bst_node_t nidx, uint32_t) const override { - static std::string const kLeafTemplate = " {nid} [ label=\"leaf={leaf-value}\" {params}]\n"; - // hardcoded limit to avoid dumping long arrays into dot graph. - bst_target_t constexpr kLimit{3}; - if (tree.IsMultiTarget()) { - auto value = tree.GetMultiTargetTree()->LeafValue(nidx); - auto result = SuperT::Match(kLeafTemplate, {{"{nid}", std::to_string(nidx)}, - {"{leaf-value}", ToStr(value, kLimit)}, - {"{params}", param_.leaf_node_params}}); - return result; + static std::string const kCoverTemplate = "\ncover={cover}"; + static std::string const kLeafTemplate = + " {nid} [ label=\"leaf={leaf-value}{cover}\" {params}]\n"; + auto plot = [&](std::string cover) { + if (tree.IsMultiTarget()) { + auto value = tree.GetMultiTargetTree()->LeafValue(nidx); + // Hardcoded limit to avoid dumping long arrays into dot graph. 
+ bst_target_t constexpr kLimit{3}; + return SuperT::Match(kLeafTemplate, {{"{nid}", std::to_string(nidx)}, + {"{leaf-value}", ToStr(value, kLimit)}, + {"{cover}", std::move(cover)}, + {"{params}", param_.leaf_node_params}}); + } else { + auto value = tree[nidx].LeafValue(); + return SuperT::Match(kLeafTemplate, {{"{nid}", std::to_string(nidx)}, + {"{leaf-value}", ToStr(value)}, + {"{cover}", std::move(cover)}, + {"{params}", param_.leaf_node_params}}); + } + }; + if (this->with_stats_) { + CHECK(!tree.IsMultiTarget()) << MTNotImplemented(); + return plot(SuperT::Match(kCoverTemplate, {{"{cover}", ToStr(tree.Stat(nidx).sum_hess)}})); } else { - auto value = tree[nidx].LeafValue(); - auto result = SuperT::Match(kLeafTemplate, {{"{nid}", std::to_string(nidx)}, - {"{leaf-value}", ToStr(value)}, - {"{params}", param_.leaf_node_params}}); - return result; + return plot(""); } } diff --git a/tests/cpp/tree/test_multi_target_tree_model.cc b/tests/cpp/tree/test_multi_target_tree_model.cc index 39e4cb4b52f0..2f01e05de0e2 100644 --- a/tests/cpp/tree/test_multi_target_tree_model.cc +++ b/tests/cpp/tree/test_multi_target_tree_model.cc @@ -60,7 +60,7 @@ TEST(MultiTargetTree, DumpDot) { auto name = "feat_" + std::to_string(f); fmap.PushBack(f, name.c_str(), "q"); } - auto str = tree->DumpModel(fmap, true, "dot"); + auto str = tree->DumpModel(fmap, false, "dot"); ASSERT_NE(str.find("leaf=[2, 3, 4]"), std::string::npos); ASSERT_NE(str.find("leaf=[3, 4, 5]"), std::string::npos); @@ -71,7 +71,7 @@ TEST(MultiTargetTree, DumpDot) { linalg::Vector weight{{1.0f, 2.0f, 3.0f, 4.0f}, {4ul}, DeviceOrd::CPU()}; tree.ExpandNode(RegTree::kRoot, /*split_idx=*/1, 0.5f, true, weight.HostView(), weight.HostView(), weight.HostView()); - auto str = tree.DumpModel(fmap, true, "dot"); + auto str = tree.DumpModel(fmap, false, "dot"); ASSERT_NE(str.find("leaf=[1, 2, ..., 4]"), std::string::npos); } } diff --git a/tests/cpp/tree/test_tree_model.cc b/tests/cpp/tree/test_tree_model.cc index 
2dc1893dd645..941c425bd9b0 100644 --- a/tests/cpp/tree/test_tree_model.cc +++ b/tests/cpp/tree/test_tree_model.cc @@ -1,11 +1,12 @@ -// Copyright by Contributors +/** + * Copyright 2018-2024, XGBoost Contributors + */ #include #include "../../../src/common/bitfield.h" #include "../../../src/common/categorical.h" #include "../filesystem.h" #include "../helpers.h" -#include "xgboost/json_io.h" #include "xgboost/tree_model.h" namespace xgboost { @@ -449,7 +450,8 @@ TEST(Tree, DumpDot) { fmap.PushBack(2, "feat_2", "int"); str = tree.DumpModel(fmap, true, "dot"); - ASSERT_NE(str.find(R"("feat_0")"), std::string::npos); + ASSERT_NE(str.find(R"("feat_0)"), std::string::npos); + ASSERT_EQ(str.find(R"("feat_0")"), std::string::npos); // newline ASSERT_NE(str.find(R"(feat_1<1)"), std::string::npos); ASSERT_NE(str.find(R"(feat_2<2)"), std::string::npos); diff --git a/tests/python/test_plotting.py b/tests/python/test_plotting.py index aaf896c6945f..1e1311c5750f 100644 --- a/tests/python/test_plotting.py +++ b/tests/python/test_plotting.py @@ -54,10 +54,10 @@ def test_plotting(self): assert ax.patches[2].get_facecolor() == (0, 0, 1.0, 1.0) # blue assert ax.patches[3].get_facecolor() == (0, 0, 1.0, 1.0) # blue - g = xgb.to_graphviz(booster, num_trees=0) + g = xgb.to_graphviz(booster, tree_idx=0) assert isinstance(g, Source) - ax = xgb.plot_tree(booster, num_trees=0) + ax = xgb.plot_tree(booster, tree_idx=0) assert isinstance(ax, Axes) def test_importance_plot_lim(self): @@ -86,9 +86,9 @@ def run_categorical(self, tree_method: str) -> None: j_tree["split_condition"], list ) - graph = xgb.to_graphviz(reg, num_trees=len(j_tree) - 1) + graph = xgb.to_graphviz(reg, tree_idx=len(j_tree) - 1) assert isinstance(graph, Source) - ax = xgb.plot_tree(reg, num_trees=len(j_tree) - 1) + ax = xgb.plot_tree(reg, tree_idx=len(j_tree) - 1) assert isinstance(ax, Axes) @pytest.mark.skipif(**tm.no_pandas())