diff --git a/src/tree/tree_model.cc b/src/tree/tree_model.cc index 040022c373a4..0639233510f7 100644 --- a/src/tree/tree_model.cc +++ b/src/tree/tree_model.cc @@ -303,9 +303,8 @@ class TextGenerator : public TreeGenerator { return result; } - std::string SplitNodeImpl( - RegTree const& tree, int32_t nid, std::string const& template_str, - std::string cond, uint32_t depth) const { + std::string SplitNodeImpl(RegTree const& tree, bst_node_t nid, std::string const& template_str, + std::string cond, uint32_t depth) const { auto split_index = tree[nid].SplitIndex(); std::string const result = SuperT::Match( template_str, @@ -345,18 +344,16 @@ class TextGenerator : public TreeGenerator { return SplitNodeImpl(tree, nid, kNodeTemplate, ToStr(cond), depth); } - std::string Categorical(RegTree const &tree, int32_t nid, - uint32_t depth) const override { + std::string Categorical(RegTree const& tree, bst_node_t nid, uint32_t depth) const override { auto cats = GetSplitCategories(tree, nid); std::string cats_str = PrintCatsAsSet(cats); static std::string const kNodeTemplate = "{tabs}{nid}:[{fname}:{cond}] yes={right},no={left},missing={missing}"; - std::string const result = - SplitNodeImpl(tree, nid, kNodeTemplate, cats_str, depth); + std::string const result = SplitNodeImpl(tree, nid, kNodeTemplate, cats_str, depth); return result; } - std::string NodeStat(RegTree const& tree, int32_t nid) const override { + std::string NodeStat(RegTree const& tree, bst_node_t nid) const override { static std::string const kStatTemplate = ",gain={loss_chg},cover={sum_hess}"; std::string const result = SuperT::Match( kStatTemplate, @@ -679,15 +676,12 @@ class GraphvizGenerator : public TreeGenerator { std::string result; if (this->with_stats_) { CHECK(!tree.IsMultiTarget()) << MTNotImplemented(); - result = SuperT::Match( - kNodeTemplate, {{"{nid}", std::to_string(nidx)}, - {"{fname}", GetFeatureName(fmap_, split_index)}, - {"{<}", has_less ? "<" : ""}, - {"{cond}", has_less ? ToStr(cond) : ""}, - {"{stat}", Match("\ncover={cover}\ngain={gain}", - {{"{cover}", std::to_string(tree.Stat(nidx).sum_hess)}, - {"{gain}", std::to_string(tree.Stat(nidx).loss_chg)}})}, - {"{params}", param_.condition_node_params}}); + result = SuperT::Match(kNodeTemplate, {{"{nid}", std::to_string(nidx)}, + {"{fname}", GetFeatureName(fmap_, split_index)}, + {"{<}", has_less ? "<" : ""}, + {"{cond}", has_less ? ToStr(cond) : ""}, + {"{stat}", this->NodeStat(tree, nidx)}, + {"{params}", param_.condition_node_params}}); } else { result = SuperT::Match(kNodeTemplate, {{"{nid}", std::to_string(nidx)}, {"{fname}", GetFeatureName(fmap_, split_index)}, @@ -703,9 +697,15 @@ class GraphvizGenerator : public TreeGenerator { return result; }; - std::string Categorical(RegTree const& tree, bst_node_t nidx, uint32_t) const override { + std::string NodeStat(RegTree const& tree, bst_node_t nidx) const override { + return Match("\ngain={gain}\ncover={cover}", + {{"{cover}", std::to_string(tree.Stat(nidx).sum_hess)}, + {"{gain}", std::to_string(tree.Stat(nidx).loss_chg)}}); + } + + std::string Categorical(RegTree const& tree, bst_node_t nidx, uint32_t /*depth*/) const override { static std::string const kLabelTemplate = - " {nid} [ label=\"{fname}:{cond}\" {params}]\n"; + " {nid} [ label=\"{fname}:{cond}{stat}\" {params}]\n"; auto cats = GetSplitCategories(tree, nidx); auto cats_str = PrintCatsAsSet(cats); auto split_index = tree.SplitIndex(nidx); @@ -714,6 +714,7 @@ class GraphvizGenerator : public TreeGenerator { SuperT::Match(kLabelTemplate, {{"{nid}", std::to_string(nidx)}, {"{fname}", GetFeatureName(fmap_, split_index)}, {"{cond}", cats_str}, + {"{stat}", this->NodeStat(tree, nidx)}, {"{params}", param_.condition_node_params}}); result += BuildEdge(tree, nidx, tree.LeftChild(nidx), true); diff --git a/tests/cpp/tree/test_tree_model.cc b/tests/cpp/tree/test_tree_model.cc index 941c425bd9b0..2491f3973f9a 100644 --- a/tests/cpp/tree/test_tree_model.cc +++ b/tests/cpp/tree/test_tree_model.cc @@ -340,6 +340,7 @@ void TestCategoricalTreeDump(std::string format, std::string sep) { ASSERT_NE(pos, std::string::npos); pos = str.find(cond_str, pos + 1); ASSERT_NE(pos, std::string::npos); + ASSERT_NE(str.find("gain"), std::string::npos); if (format == "json") { // Make sure it's valid JSON