Trees are an important data structure in mlpack and are used in a number of the machine learning algorithms that mlpack implements. Often, the use of trees can allow significant acceleration of an algorithm; this is generally done by pruning away large parts of the tree during computation.
Most mlpack algorithms that use trees are not tied to a specific tree but
instead allow the user to choose a tree via the TreeType
template parameter.
Any tree passed as a TreeType
template parameter will need to implement a
certain set of functions. In addition, a tree may optionally specify some
traits about itself with the TreeTraits
trait class.
This document aims to clarify the abstractions underlying mlpack trees, list and
describe the required functionality of the TreeType
policy, and point users
towards existing types of trees.
Although this document is long, there may still be errors and unclear areas. If you are having trouble understanding anything, please get in touch on Github or on the mailing list and someone will help you (and possibly update the documentation afterwards).
In mlpack, we assume that we have some sort of data matrix, which might be
sparse or dense (that is, it could be of type arma::mat
or arma::sp_mat
,
or any variant that implements the Armadillo API). This data matrix corresponds
to a collection of points in some space (usually a Euclidean space). A tree is
a way of organizing this data matrix in a hierarchical manner---so, points that
are nearby should lie in similar nodes.
We can rigorously define what a tree is, using the definition of space tree introduced in the following paper:
R.R. Curtin, W.B. March, P. Ram, D.V. Anderson, A.G. Gray, and C.L. Isbell Jr.,
"Tree-independent dual-tree algorithms," in Proceedings of the 30th
International Conference on Machine Learning (ICML '13), pp. 1435--1443, 2013.
The definition is:
A space tree on a dataset S
in R^(N x d)
is an undirected, connected,
acyclic, rooted simple graph with the following properties:
-
Each node (or vertex) holds a number of points (possibly zero) and is connected to one parent node and a number of child nodes (possibly zero).
-
There is one node in every space tree with no parent; this is the root node of the tree.
-
Each point in
S
is contained in at least one node. -
Each node corresponds to some subset of
R^d
that contains each point in the node and also the subsets that correspond to each child of the node.
This is really a quite straightforward definition: a tree is hierarchical, and each node corresponds to some region of the input space. Each node may have some number of children, and may hold some number of points. However, there is an important terminology distinction to make: the term points held by a node has a different meaning than the term descendant points held by a node. The points held in a node are just that---points held only in the node. The descendant points of a node are the combination of the points held in a node with the points held in the node's children and the points held in the node's children's children (and so forth). For the purposes of clarity in all discussions about trees, care is taken to differentiate the terms "descendant point" and "point".
Now, it's also important to note that a point does not need to hold any
children, and that a node can hold the same points as its children (or its
parent). Some types of trees do this. For instance, each node in the cover
tree holds only one point, and may have a child that holds the same point. As
another example, the kd
-tree holds its points only in the leaves (at the
bottom of the tree). More information on space trees can be found in either the
"Tree-independent dual-tree algorithms" paper or any of the related literature.
So there is a huge amount of possible variety in the types of trees that can
fall into the class of space trees. Therefore, it's important to treat them
abstractly, and the TreeType
policy allows us to do just that. All we need
to remember is that a node in a tree can be represented as the combination of
some points held in the node, some child nodes, and some geometric structure
that represents the space that all of the descendant points fall into (this is a
restatement of the fourth part of the definition).
Most everything in mlpack is decomposed into a series of configurable template
parameters, and trees are no exception. In order to ease usage of high-level
mlpack algorithms, each TreeType
itself must be a template class taking three
parameters:
MetricType
-- the underlying metric that the tree will be built on (see the MetricType policy documentation)StatisticType
-- holds any auxiliary information that individual algorithms may needMatType
-- the type of the matrix used to represent the data
The reason that these three template parameters are necessary is so that each
TreeType
can be used as a template template parameter, which can radically
simplify the required syntax for instantiating mlpack algorithms. By using
template template parameters, a user needs only to write
// The RangeSearch class takes a MetricType and a TreeType template parameter.
// This code instantiates RangeSearch with the ManhattanDistance and a
// QuadTree. Note that the QuadTree itself is a template, and takes a
// MetricType, StatisticType, and MatType, just like the policy requires.
// This example ignores the constructor parameters, for the sake of simplicity.
RangeSearch<ManhattanDistance, QuadTree> rs(...);
as opposed to the far more complicated alternative, where the user must specify the values of each template parameter of the tree type:
// This is a much worse alternative, where the user must specify the template
// arguments of their tree.
RangeSearch<ManhattanDistance,
QuadTree<ManhattanDistance, EmptyStatistic, arma::mat>> rs(...);
Unfortunately, the price to pay for this user convenience is that every
TreeType
must have three template parameters, and they must be in exactly
that order. Fortunately, there is an additional benefit: we are guaranteed that
the tree is built using the same metric as the method (that is, a user can't
specify different metric types to the algorithm and to the tree, which they can
without template template parameters).
There are two important notes about this:
-
Not every possible input of
MetricType
,StatisticType
, and/orMatType
necessarily need to be valid or work correctly for each type of tree. For instance, theQuadTree
is limited to Euclidean metrics and will not work otherwise. Either compile-time static checks or detailed documentation can help keep users from using invalid combinations of template arguments. -
Some types of trees have more template parameters than just these three. One example is the generalized binary space tree, where the bounding shape of each node is easily made into a fourth template parameter (the
BinarySpaceTree
class calls this theBoundType
parameter), and the procedure used to split a node is easily made into a fifth template parameter (theBinarySpaceTree
class calls this theSplitType
parameter). However, the syntax of template template parameters requires that the class only has the correct number of template parameters---no more, no less. Fortunately, C++11 allows template typedefs, which can be used to provide partial specialization of template classes:
// This is the definition of the BinarySpaceTree class, which has five template
// parameters.
template<typename MetricType,
typename StatisticType,
typename MatType,
typename BoundType,
typename SplitType>
class BinarySpaceTree;
// The 'using' keyword gives us a template typedef, so we can define the
// MeanSplitKDTree template class, which has three parameters and is a valid
// TreeType policy class.
template<typename MetricType, typename StatisticType, typename MatType>
using MeanSplitKDTree = BinarySpaceTree<MetricType,
StatisticType,
MatType,
HRectBound<MetricType>
MeanSplit<BoundType, MetricType>>;
Now, the MeanSplitKDTree
class has only three template parameters and can be
used as a TreeType
policy class in various mlpack algorithms. Many types of
trees in mlpack have more than three template parameters and rely on template
typedefs to provide simplified TreeType
interfaces.
As a result of the definition of space tree in the previous section, a
simplified API presents itself quite easily. However, more complex
functionality is often necessary in mlpack, so this leads to more functions
being necessary for a class to satisfy the TreeType
policy. Combining this
with the template parameters required for trees given in the previous section
gives us the complete API required for a class implementing the TreeType
policy. Below is the minimal set of functions required with minor documentation
for each function. (More extensive documentation and explanation is given
afterwards.)
// The three template parameters will be supplied by the user, and are detailed
// in the previous section.
template<typename MetricType,
typename StatisticType,
typename MatType>
class ExampleTree
{
public:
//////////////////////
//// Constructors ////
//////////////////////
// This batch constructor does not modify the dataset, and builds the entire
// tree using a default-constructed MetricType.
ExampleTree(const MatType& data);
// This batch constructor does not modify the dataset, and builds the entire
// tree using the given MetricType.
ExampleTree(const MatType& data, MetricType& metric);
// Initialize the tree from a given cereal archive. SFINAE (the
// second argument) is necessary to ensure that the archive is loading, not
// saving.
template<typename Archive>
ExampleTree(
Archive& ar,
const typename std::enable_if_c<typename Archive::is_loading>::type* = 0);
// Release any resources held by the tree.
~ExampleTree();
// ///////////////////////// //
// // Basic functionality // //
// ///////////////////////// //
// Get the dataset that the tree is built on.
const MatType& Dataset();
// Get the metric that the tree is built with.
MetricType& Metric();
// Get/modify the StatisticType for this node.
StatisticType& Stat();
// Return the parent of the node, or NULL if this is the root.
ExampleTree* Parent();
// Return the number of children held by the node.
size_t NumChildren();
// Return the i'th child held by the node.
ExampleTree& Child(const size_t i);
// Return the number of points held in the node.
size_t NumPoints();
// Return the index of the i'th point held in the node.
size_t Point(const size_t i);
// Return the number of descendant nodes of this node.
size_t NumDescendantNodes();
// Return the i'th descendant node of this node.
ExampleTree& DescendantNode(const size_t i);
// Return the number of descendant points of this node.
size_t NumDescendants();
// Return the index of the i'th descendant point of this node.
size_t Descendant(const size_t i);
// Store the center of the bounding region of the node in the given vector.
void Center(arma::vec& center);
// ///////////////////////////////////////////////// //
// // More complex distance-related functionality // //
// ///////////////////////////////////////////////// //
// Return the distance between the center of this node and the center of
// its parent.
double ParentDistance();
// Return an upper bound on the furthest possible distance between the
// center of the node and any point held in the node.
double FurthestPointDistance();
// Return an upper bound on the furthest possible distance between the
// center of the node and any descendant point of the node.
double FurthestDescendantDistance();
// Return a lower bound on the minimum distance between the center and any
// edge of the node's bounding shape.
double MinimumBoundDistance();
// Return a lower bound on the minimum distance between the given point and
// the node.
template<typename VecType>
double MinDistance(VecType& point);
// Return a lower bound on the minimum distance between the given node and
// this node.
double MinDistance(ExampleTree& otherNode);
// Return an upper bound on the maximum distance between the given point and
// the node.
template<typename VecType>
double MaxDistance(VecType& point);
// Return an upper bound on the maximum distance between the given node and
// this node.
double MaxDistance(ExampleTree& otherNode);
// Return the combined results of MinDistance() and MaxDistance().
template<typename VecType>
Range RangeDistance(VecType& point);
// Return the combined results of MinDistance() and MaxDistance().
Range RangeDistance(ExampleTree& otherNode);
// //////////////////////////////////// //
// // Serialization (loading/saving) // //
// //////////////////////////////////// //
// Return a string representation of the tree.
std::string ToString() const;
// Serialize the tree (load from the given archive / save to the given
// archive, depending on its type).
template<typename Archive>
void serialize(Archive& ar, const uint32_t version);
protected:
// A default constructor; only meant to be used by cereal. This
// must be protected so that cereal will work; it does not need
// to return a valid tree.
ExampleTree();
// Friend access must be given for the default constructor.
friend class cereal::access;
};
Although this is significantly more complex than the four-item definition of space tree might suggest, it turns out many of these methods are not difficult to implement for most reasonable tree types. It is also important to realize that this is a minimum API; you may implement more complex tree types at your leisure (and you may include more template parameters too, though you will have to use template typedefs to provide versions with three parameters; see the previous section).
Before diving into the detailed documentation for each function, let us consider a few important points about the implications of this API:
-
Trees are not default-constructible and should not (in general) provide a default constructor. This helps prevent invalid trees. In general, any instantiated mlpack object should be valid and ready to use---and a tree built on no points is not valid or ready to use.
-
Trees only need to provide batch constructors. Although many tree types do have algorithms for incremental insertions, in mlpack this is not required because the tree-based algorithms that mlpack implements generally assume fully-built, non-modifiable trees. For this purpose, batch construction is perfectly sufficient. (It's also worth pointing out that for some types of trees, like kd-trees, the cost of a handful of insertions often outweighs the cost of completely rebuilding the tree.)
-
Trees must provide a number of distance bounding functions. The utility of trees generally stems from the ability to place quick bounds on distance-related quantities. For instance, if all the descendant points of a node are bounded by a ball of radius
r
and the center of the node is a pointc
, then the minimum distance between some pointp
and any descendant point of the node is equal to the distance betweenp
andc
minus the radiusr
:d(p, c) - r
. This is a fast calculation, and (usually) provides a decent bound on the minimum distance betweenp
and any descendant point of the node. -
Trees need to be able to be serialized. mlpack uses the cereal library for saving and loading objects. Trees---which can be a part of machine learning models---therefore must have the ability to be saved and loaded. Making this all work requires a protected constructor (part of the API) and generally makes it impossible to hold references instead of pointers internally, because if a tree is loaded from a file then it must own the dataset it is built on and the metric it uses (this also means that a destructor must exist for freeing these resources).
Now, we can consider each part of the API more rigorously.
This section is divided into five parts, detailing each of the parts of the API
An earlier section discussed the three different template parameters that are
required by the TreeType
policy.
The MetricType policy provides one method that will be useful for tree building and other operations:
// This function is required by the MetricType policy.
// Evaluate the metric between two points (which may be of different types).
template<typename VecTypeA, typename VecTypeB>
double Evaluate(const VecTypeA& a, const VecTypeB& b);
Note that this method is not necessarily static, so a MetricType
object should
be held internally and its Evaluate()
method should be called whenever the
distance between two points is required. It is generally a bad idea to
hardcode any distance calculation in your tree. This will make the tree unable
to generalize to arbitrary metrics. If your tree must depend on certain
assumptions holding about the metric (i.e. the metric is a Euclidean metric),
then make that clear in the documentation of the tree, so users do not try to
use the tree with an inappropriate metric.
The second template parameter, StatisticType
, is for auxiliary information
that is required by certain algorithms. For instance, consider an algorithm
which repeatedly uses the variance of the descendant points of a node. It might
be tempting to add a Variance()
method to the required TreeType
API, but
this quickly leads to code bloat (after all, the API already has quite enough
functions as it is). Instead, it is better to create a StatisticType
class
which provides the Variance()
method, and then call Stat().Variance()
when
the variance is required. This also holds true for cached data members.
Each node should have its own instance of a StatisticType
class. The
StatisticType
must provide the following constructors:
// Default constructor required by the StatisticType policy.
StatisticType();
// This constructor is required by the StatisticType policy.
template<typename TreeType>
StatisticType(TreeType& node);
This constructor should be called with (*this)
after the node is constructed
(usually, this ends up being the last line in the constructor of a node).
The last template parameter is the MatType
parameter. This is generally
arma::mat
or arma::sp_mat
, but could be any Armadillo type, including
matrices that hold data points of different precisions (such as float
or even
int
). It generally suffices to write MatType
assuming that arma::mat
will be used, since the vast majority of the time this will be what is used.
The TreeType
API requires at least three constructors. Technically, it does
not require a destructor, but almost certainly your tree class will be doing
some memory management internally and should have one (though not always).
The first two constructors are variations of the same idea:
// This batch constructor does not modify the dataset, and builds the entire
// tree using a default-constructed MetricType.
ExampleTree(const MatType& data);
// This batch constructor does not modify the dataset, and builds the entire
// tree using the given MetricType.
ExampleTree(const MatType& data, MetricType& metric);
All that is required here is that a constructor is available that takes a
dataset and optionally an instantiated metric. If no metric is provided, then
it should be assumed that the MetricType
class has a default constructor and
a default-constructed metric should be used. The constructor must return a
valid, fully-constructed, ready-to-use tree that satisfies the definition
of space tree that was given earlier in the document.
The third constructor requires the tree to be initializable from a cereal
archive:
// Initialize the tree from a given cereal archive. SFINAE (the
// second argument) is necessary to ensure that the archive is loading, not
// saving.
template<typename Archive>
ExampleTree(
Archive& ar,
const typename std::enable_if_c<typename Archive::is_loading>::type* = 0);
This has implications on how the tree must be stored. In this case, the dataset is not yet loaded and therefore the tree may be required to have ownership of the data matrix. This means that realistically the most reasonable way to represent the data matrix internally in a tree class is not with a reference but instead with a pointer. If this is true, then a destructor will be required:
// Release any resources held by the tree.
~ExampleTree();
and, if the data matrix is represented internally with a pointer, this
destructor will need to release the memory for the data matrix (in the case that
the tree was created via cereal
).
Note that these constructors are not necessarily the only constructors that a
TreeType
implementation can provide. One important example of when more
constructors are useful is when the tree rearranges points internally; this
might be desired for the sake of speed or memory optimization. But to do this
with the required constructors would necessarily incur a copy of the data
matrix, because the user will pass a const MatType&
. One alternate solution
is to provide a constructor which takes an rvalue reference to a MatType
:
template<typename Archive>
ExampleTree(MatType&& data);
(and another overload that takes an instantiated metric), and then the user can
use std::move()
to build the tree without copying the data matrix, although
the data matrix will be modified:
ExampleTree exTree(std::move(dataset));
It is, of course, possible to add even more constructors if desired.
The basic functionality of a class implementing the TreeType
API is quite
straightforward and intuitive.
// Get the dataset that the tree is built on.
const MatType& Dataset();
This should return a const
reference to the dataset the tree is built on. The
fact that this function is required essentially means that each node in the tree
must store a pointer to the dataset (this is not the only option, but it is the
most obvious option).
// Get the metric that the tree is built with.
MetricType& Metric();
Each node must also store an instantiated metric or a pointer to one (note that
this is required even for metrics that have no state and have a static
Evaluate()
function).
// Get/modify the StatisticType for this node.
StatisticType& Stat();
As discussed earlier, each node must hold a StatisticType
; this is accessible
through the Stat()
function.
// Return the parent of the node, or NULL if this is the root.
ExampleTree* Parent();
// Return the number of children held by the node.
size_t NumChildren();
// Return the i'th child held by the node.
ExampleTree& Child(const size_t i);
// Return the number of points held in the node.
size_t NumPoints();
// Return the index of the i'th point held in the node.
size_t Point(const size_t i);
// Return the number of descendant nodes of this node.
size_t NumDescendantNodes();
// Return the i'th descendant node of this node.
ExampleTree& DescendantNode(const size_t i);
// Return the number of descendant points of this node.
size_t NumDescendants();
// Return the index of the i'th descendant point of this node.
size_t Descendant(const size_t i);
These functions are all fairly self-explanatory. Most algorithms will use the
Parent()
, Children()
, NumChildren()
, Point()
, and NumPoints()
functions, so care should be taken when implementing those functions to ensure
they will be efficient. Note that Point()
and Descendant()
should return
indices of points, so the actual points can be accessed by calling
Dataset().col(Point(i))
for some index i
(or something similar).
An important note about the Descendant()
function is that each descendant
point should be unique. So if a node holds the point with index 6 and it has
one child that holds the points with indices 6 and 7, then NumDescendants()
should return 2, not 3. The ordering in which the descendants are returned can
be arbitrary; so, Descendant(0)
can return 6 or 7, and Descendant(1)
should return the other index.
// Store the center of the bounding region of the node in the given vector.
void Center(arma::vec& center);
The last function, Center()
, should calculate the center of the bounding shape
and store it in the given vector. So, for instance, if the tree is a ball tree,
then the center is simply the center of the ball. Algorithm writers would be
wise to try and avoid the use of Center()
if possible, since it will
necessarily cost a copy of a vector.
A node in a tree should also be able to calculate various distance-related bounds; these are particularly useful in tree-based algorithms. Note that any of these bounds does not necessarily need to be maximally tight; generally it is more important that each bound can be easily calculated.
Details on each bounding function that the TreeType
API requires are given
below.
// Return the distance between the center of this node and the center of
// its parent.
double ParentDistance();
Remember that each node corresponds to some region in the space that the dataset
lies in. For most tree types this shape is often something geometrically
simple: a ball, a cone, a hyperrectangle, a slice, or something similar. The
ParentDistance()
function should return the distance between the center of
this node's region and the center of the parent node's region.
In practice this bound is often used in dual-tree (or single-tree) algorithms to
place an easy MinDistance()
(or MaxDistance()
) bound for a child node; the
parent's MinDistance()
(or MaxDistance()
) function is called and then
adjusted with ParentDistance()
to provide a possibly loose but efficient bound
on what the result of MinDistance()
(or MaxDistance()
) would be with the
child.
// Return an upper bound on the furthest possible distance between the
// center of the node and any point held in the node.
double FurthestPointDistance();
// Return an upper bound on the furthest possible distance between the
// center of the node and any descendant point of the node.
double FurthestDescendantDistance();
It is often very useful to be able to bound the radius of a node, which is
effectively what FurthestDescendantDistance()
does. Often it is easiest to
simply calculate and cache the furthest descendant distance at tree construction
time. Some trees, such as the cover tree, are able to give guarantees that the
points held in the node will necessarily be closer than the descendant points;
therefore, the FurthestPointDistance()
function is also useful.
It is permissible to simply have FurthestPointDistance()
return the result of
FurthestDescendantDistance()
, and that will still be a valid bound, but
depending on the type of tree it may be possible to have
FurthestPointDistance()
return a tighter bound.
// Return a lower bound on the minimum distance between the center and any
// edge of the node's bounding shape.
double MinimumBoundDistance();
This is, admittedly, a somewhat complex and weird quantity. It is one of the less important bounding functions, so it is valid to simply return 0...
The bound is a bound on the minimum distance between the center of the node and
any edge of the shape that bounds all of the descendants of the node. So, if
the bounding shape is a ball (as in a ball tree or a cover tree), then
MinimumBoundDistance()
should just return the radius of the ball. If the
bounding shape is a hypercube (as in a generalized octree), then
MinimumBoundDistance()
should return the side length divided by two. If the
bounding shape is a hyperrectangle (as in a kd-tree or a spill tree), then
MinimumBoundDistance()
should return half the side length of the
hyperrectangle's smallest side.
// Return a lower bound on the minimum distance between the given point and
// the node.
template<typename VecType>
double MinDistance(VecType& point);
// Return a lower bound on the minimum distance between the given node and
// this node.
double MinDistance(ExampleTree& otherNode);
// Return an upper bound on the maximum distance between the given point and
// the node.
template<typename VecType>
double MaxDistance(VecType& point);
// Return an upper bound on the maximum distance between the given node and
// this node.
double MaxDistance(ExampleTree& otherNode);
// Return the combined results of MinDistance() and MaxDistance().
template<typename VecType>
Range RangeDistance(VecType& point);
// Return the combined results of MinDistance() and MaxDistance().
Range RangeDistance(ExampleTree& otherNode);
These six functions are almost without a doubt the most important functionality of a tree. Therefore, it is preferable that these methods be implemented as efficiently as possible, as they may potentially be called many millions of times in a tree-based algorithm. It is also preferable that these bounds be as tight as possible. In tree-based algorithms, these are used for pruning away work, and tighter bounds mean that more pruning is possible.
Of these six functions, there are only really two bounds that are desired here:
the minimum distance between a node and an object, and the maximum distance
between a node and an object. The object may be either a vector (usually
arma::vec
) or another tree node.
Consider the first case, where the object is a vector. The result of
MinDistance()
needs to be less than or equal to the true minimum distance,
which could be calculated as below:
// We assume that we have a vector 'vec', and a tree node 'node'.
double trueMinDist = DBL_MAX;
for (size_t i = 0; i < node.NumDescendants(); ++i)
{
const double dist = node.Metric().Evaluate(vec,
node.Dataset().col(node.Descendant(i)));
if (dist < trueMinDist)
trueMinDist = dist;
}
// At the end of the loop, trueMinDist will hold the true minimum distance
// between 'vec' and any descendant point of 'node'.
Often the bounding shape of a node will allow a quick calculation that will make
a reasonable bound. For instance, if the node's bounding shape is a ball with
radius r
and center ctr
, the calculation is simply
(node.Metric().Evaluate(vec, ctr) - r)
. Usually a good MinDistance()
or
MaxDistance()
function will make only one call to the Evaluate()
function of
the metric.
The RangeDistance()
function allows a way for both bounds to be calculated at
once. It is possible to implement this as a call to MinDistance()
followed by
a call to MaxDistance()
, but this may incur more metric Evaluate()
calls
than necessary. Often calculating both bounds at once can be more efficient and
can be done with fewer Evaluate()
calls than calling both MinDistance()
and
MaxDistance()
.
The last functions that the TreeType
API requires are for serialization.
// Serialize the tree (load from the given archive / save to the given
// archive, depending on its type).
template<typename Archive>
void serialize(Archive& ar, const unsigned int version);
protected:
// A default constructor; only meant to be used by cereal. This
// must be protected so that cereal will work; it does not need
// to return a valid tree.
ExampleTree();
// Friend access must be given for the default constructor.
friend class cereal::access;
On the other hand, the specifics of the functionality required for the
serialize()
function are somewhat more difficult. The serialize()
function
will be called either when a tree is being saved to disk or loaded from disk.
The cereal
documentation is fairly comprehensive.
An important note is that it is very difficult to use references with cereal
,
because serialize()
may be called at any time during the object's lifetime,
and references cannot be re-seated. In general this will require the use of
pointers, which then require manual memory management. Therefore, be careful
that serialize()
(and the tree's destructor) properly handle memory
management!
Some tree-based algorithms can specialize if the tree fulfills certain
conditions. For instance, if the regions represented by two sibling nodes
cannot overlap, an algorithm may be able to perform a simpler computation.
Based on this reasoning, the TreeTraits
trait class (much like the
KernelTraits
class) exists in order to allow a tree to specify (via a const static bool
) when these types of conditions are satisfied. Note that a
TreeTraits class is not required, but may be helpful.
The TreeTraits
trait class is a template class that takes a TreeType
as a
parameter, and exposes const static bool
values that depend on the tree.
Setting these values is achieved by specialization. The code below shows the
default TreeTraits
values (these are the values that will be used if no
specialization is provided for a given TreeType
).
template<typename TreeType>
class TreeTraits
{
public:
// This is true if the subspaces represented by the children of a node can
// overlap.
static const bool HasOverlappingChildren = true;
// This is true if Point(0) is the centroid of the node.
static const bool FirstPointIsCentroid = false;
// This is true if the points contained in the first child of a node
// (Child(0)) are also contained in that node.
static const bool HasSelfChildren = false;
// This is true if the tree rearranges points in the dataset when it is built.
static const bool RearrangesDataset = false;
// This is true if the tree always has only two children.
static const bool BinaryTree = false;
};
An example specialization for the :KDTree
class is given below. Note that
KDTree
is itself a template class (like every class satisfying the TreeType
policy), so we are specializing to a template parameter.
template<typename MetricType,
typename StatisticType,
typename MatType>
template<>
class TreeTraits<KDTree<MetricType, StatisticType, MatType>>
{
public:
// The regions represented by the two children of a node may not overlap.
static const bool HasOverlappingChildren = false;
// There is no guarantee that the first point of a node is the centroid.
static const bool FirstPointIsCentroid = false;
// Points are not contained at multiple levels (only at the leaves).
static const bool HasSelfChildren = false;
// Points are rearranged during the building of the tree.
static const bool RearrangesDataset = true;
// The tree is always binary.
static const bool BinaryTree = true;
};
Currently, the traits available are each of the five detailed above. For more
information, see the TreeTraits
source code for more documentation.
mlpack contains several ready-to-use implementations of trees that satisfy the TreeType policy API:
KDTree
MeanSplitKDTree
BallTree
MeanSplitBallTree
RTree
RStarTree
StandardCoverTree
Often, these are template typedefs of more flexible tree classes:
BinarySpaceTree
-- binary trees, such as the KD-tree and ball treeRectangleTree
-- the R tree and variantsCoverTree
-- the cover tree and variants