AlgebraicDecisionTree Helpers #1696

Merged
54 commits merged on Aug 21, 2024

Commits
7695fd6
Improved HybridBayesNet::optimize with proper model selection
varunagrawal Nov 20, 2023
39f7ac2
handle nullptrs in GaussianMixture::error
varunagrawal Nov 20, 2023
c374a26
nicer HybridBayesNet::optimize with normalized errors
varunagrawal Nov 20, 2023
ed5ef66
Merge branch 'hybrid-printerrors' into model-selection-integration
varunagrawal Nov 27, 2023
50670da
HybridValues formatting
varunagrawal Dec 12, 2023
af490e9
sum and normalize helper methods for the AlgebraicDecisionTree
varunagrawal Dec 12, 2023
c004bd8
test for differing covariances
varunagrawal Dec 13, 2023
7b56c96
differing means test
varunagrawal Dec 15, 2023
e549a9b
normalize model selection term
varunagrawal Dec 15, 2023
b2638c8
max and min functions for AlgebraicDecisionTree
varunagrawal Dec 17, 2023
6f09be5
error normalization and log-sum-exp trick
varunagrawal Dec 17, 2023
3660429
handle numerical instability
varunagrawal Dec 18, 2023
07ddec5
remove stray comment
varunagrawal Dec 21, 2023
c6584f6
minor test cleanup
varunagrawal Dec 22, 2023
ebcf958
better, more correct version of model selection
varunagrawal Dec 25, 2023
1e298be
Better way of handling assignments
varunagrawal Dec 25, 2023
b4f07a0
cleaner model selection computation
varunagrawal Dec 26, 2023
6f4343c
almost working
varunagrawal Dec 26, 2023
409938f
improved model selection code
varunagrawal Dec 26, 2023
93c824c
overload == operator for GaussianBayesNet and VectorValues
varunagrawal Dec 27, 2023
b20d33d
logNormalizationConstant() for GaussianBayesNet
varunagrawal Dec 27, 2023
3a89653
helper methods in GaussianMixture for model selection
varunagrawal Dec 27, 2023
6f66d04
handle pruning in model selection
varunagrawal Dec 27, 2023
0d05810
update wrapper for LM with Ordering parameter
varunagrawal Jan 3, 2024
651f999
print logNormalizationConstant for Gaussian conditionals
varunagrawal Jan 3, 2024
114c86f
GaussianConditional wrapper for arbitrary number of keys
varunagrawal Jan 3, 2024
82e0c0d
take comment all the way
varunagrawal Jan 3, 2024
8a61c49
add model_selection method to HybridBayesNet
varunagrawal Jan 3, 2024
3ba54eb
improved docstrings
varunagrawal Jan 3, 2024
bb95cd4
remove `using std::dynamic_pointer_cast;`
varunagrawal Jan 3, 2024
6d50de8
docstring for HybridBayesNet::assembleTree
varunagrawal Jan 3, 2024
9ad7697
Merge branch 'hybrid-printerrors' into model-selection-integration
varunagrawal Jan 4, 2024
502e8cf
Merge branch 'model-selection-integration' into hybrid-lognormconstant
varunagrawal Jan 4, 2024
a80b5d4
Merge branch 'hybrid-printerrors' into model-selection-integration
varunagrawal Jan 5, 2024
0430fee
improved naming and documentation
varunagrawal Jan 7, 2024
afcb933
document return type
varunagrawal Jan 7, 2024
c5bfd52
better printing of GaussianMixtureFactor
varunagrawal Jan 12, 2024
e9e2ef9
Merge pull request #1705 from borglab/hybrid-lognormconstant
varunagrawal Feb 20, 2024
538871a
Merge branch 'develop' into model-selection-integration
varunagrawal Mar 18, 2024
1501b7c
Merge branch 'develop' into model-selection-integration
varunagrawal Jun 28, 2024
eb9ea78
Merge branch 'develop' into model-selection-integration
varunagrawal Jul 2, 2024
a9cf4a0
fix namespacing
varunagrawal Jul 3, 2024
2a080bb
Merge branch 'develop' into model-selection-integration
varunagrawal Jul 29, 2024
113a7f8
added more comments and compute GaussianMixture before tau
varunagrawal Aug 5, 2024
2430abb
test for different error values in BN from MixtureFactor
varunagrawal Aug 7, 2024
3c722ac
update GaussianMixtureFactor to record normalizers, and add unit tests
varunagrawal Aug 20, 2024
d4e5a9b
different means test both via direct factor definition and toFactorGraph
varunagrawal Aug 20, 2024
fef929f
clean up model selection
varunagrawal Aug 20, 2024
654bad7
remove model selection code
varunagrawal Aug 20, 2024
6b1d89d
fix testMixtureFactor test
varunagrawal Aug 20, 2024
6d9fc8e
undo change in GaussianMixture
varunagrawal Aug 20, 2024
fd2062b
remove changes so we can break up PR into smaller ones
varunagrawal Aug 20, 2024
cea84b8
reduce the diff even more
varunagrawal Aug 20, 2024
73d971a
unit tests for AlgebraicDecisionTree helper methods
varunagrawal Aug 21, 2024
36 changes: 36 additions & 0 deletions gtsam/discrete/AlgebraicDecisionTree.h
@@ -196,6 +196,42 @@ namespace gtsam {
return this->apply(g, &Ring::div);
}

/// Compute sum of all values
Member (review comment): Unit tests?
double sum() const {
double sum = 0;
auto visitor = [&](double y) { sum += y; };
this->visit(visitor);
return sum;
}

/**
* @brief Helper method to perform normalization such that all leaves in the
* tree sum to 1.
*
* @param sum The sum of all leaf values, used as the normalization constant.
* @return AlgebraicDecisionTree Tree with each leaf divided by `sum`.
*/
AlgebraicDecisionTree normalize(double sum) const {
return this->apply([&sum](const double& x) { return x / sum; });
}

/// Find the minimum value amongst all leaves
double min() const {
double min = std::numeric_limits<double>::max();
auto visitor = [&](double x) { min = x < min ? x : min; };
this->visit(visitor);
return min;
}

/// Find the maximum value amongst all leaves
double max() const {
// Start from the most negative representable value
double max = -std::numeric_limits<double>::max();
auto visitor = [&](double x) { max = x > max ? x : max; };
this->visit(visitor);
return max;
}

/** sum out variable */
AlgebraicDecisionTree sum(const L& label, size_t cardinality) const {
return this->combine(label, cardinality, &Ring::add);
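To see how the new helpers compose, here is a minimal usage sketch; it is not part of the PR's test suite, and the single binary key and the leaf values 1 and 3 are made up for illustration:

#include <gtsam/discrete/AlgebraicDecisionTree.h>
#include <gtsam/discrete/DiscreteKey.h>

#include <iostream>

using namespace gtsam;

int main() {
  // Tree over a single binary discrete key, with leaf values 1 and 3.
  DiscreteKey m(0, 2);
  AlgebraicDecisionTree<Key> tree(m, 1.0, 3.0);

  double total = tree.sum();  // 1 + 3 = 4
  // Divide every leaf by the total so the leaves sum to 1 (0.25 and 0.75).
  AlgebraicDecisionTree<Key> probabilities = tree.normalize(total);

  std::cout << "sum: " << total                            // 4
            << ", min: " << tree.min()                     // 1
            << ", max: " << tree.max()                     // 3
            << ", normalized sum: " << probabilities.sum() // 1
            << std::endl;
  return 0;
}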
6 changes: 4 additions & 2 deletions gtsam/hybrid/GaussianMixture.h
@@ -67,7 +67,8 @@ class GTSAM_EXPORT GaussianMixture
double logConstant_; ///< log of the normalization constant.

/**
* @brief Convert a DecisionTree of factors into a DT of Gaussian FGs.
* @brief Convert a DecisionTree of factors into
* a DecisionTree of Gaussian factor graphs.
*/
GaussianFactorGraphTree asGaussianFactorGraphTree() const;

@@ -214,7 +215,8 @@ class GTSAM_EXPORT GaussianMixture
* @return AlgebraicDecisionTree<Key> A decision tree on the discrete keys
* only, with the leaf values as the error for each assignment.
*/
AlgebraicDecisionTree<Key> errorTree(const VectorValues &continuousValues) const;
AlgebraicDecisionTree<Key> errorTree(
const VectorValues &continuousValues) const;

/**
* @brief Compute the logProbability of this Gaussian Mixture.
3 changes: 2 additions & 1 deletion gtsam/hybrid/GaussianMixtureFactor.h
@@ -135,7 +135,8 @@ class GTSAM_EXPORT GaussianMixtureFactor : public HybridFactor {
* @return AlgebraicDecisionTree<Key> A decision tree with the same keys
* as the factors involved, and leaf values as the error.
*/
AlgebraicDecisionTree<Key> errorTree(const VectorValues &continuousValues) const;
AlgebraicDecisionTree<Key> errorTree(
const VectorValues &continuousValues) const;

/**
* @brief Compute the log-likelihood, including the log-normalizing constant.
15 changes: 6 additions & 9 deletions gtsam/hybrid/HybridValues.h
@@ -12,6 +12,7 @@
/**
* @file HybridValues.h
* @date Jul 28, 2022
* @author Varun Agrawal
* @author Shangjie Xue
*/

@@ -54,13 +55,13 @@ class GTSAM_EXPORT HybridValues {
HybridValues() = default;

/// Construct from DiscreteValues and VectorValues.
HybridValues(const VectorValues &cv, const DiscreteValues &dv)
: continuous_(cv), discrete_(dv){}
HybridValues(const VectorValues& cv, const DiscreteValues& dv)
: continuous_(cv), discrete_(dv) {}

/// Construct from all values types.
HybridValues(const VectorValues& cv, const DiscreteValues& dv,
const Values& v)
: continuous_(cv), discrete_(dv), nonlinear_(v){}
: continuous_(cv), discrete_(dv), nonlinear_(v) {}

/// @}
/// @name Testable
@@ -101,9 +102,7 @@ class GTSAM_EXPORT HybridValues {
bool existsDiscrete(Key j) { return (discrete_.find(j) != discrete_.end()); }

/// Check whether a variable with key \c j exists in values.
bool existsNonlinear(Key j) {
return nonlinear_.exists(j);
}
bool existsNonlinear(Key j) { return nonlinear_.exists(j); }

/// Check whether a variable with key \c j exists.
bool exists(Key j) {
@@ -128,9 +127,7 @@
}

/// insert_or_assign() , similar to Values.h
void insert_or_assign(Key j, size_t value) {
discrete_[j] = value;
}
void insert_or_assign(Key j, size_t value) { discrete_[j] = value; }

/** Insert all continuous values from \c values. Throws an invalid_argument
* exception if any keys to be inserted are already used. */
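For context, a minimal sketch of how the HybridValues constructors and accessors touched above fit together; the keys and values are arbitrary and chosen only for illustration:

#include <gtsam/base/Vector.h>
#include <gtsam/discrete/DiscreteValues.h>
#include <gtsam/hybrid/HybridValues.h>
#include <gtsam/linear/VectorValues.h>

using namespace gtsam;

int main() {
  VectorValues cv;
  cv.insert(1, Vector2(1.0, 2.0));  // continuous variable with key 1

  DiscreteValues dv;
  dv[10] = 1;  // discrete mode with key 10 set to 1

  // Construct from continuous and discrete values.
  HybridValues hv(cv, dv);

  // insert_or_assign overwrites the existing discrete assignment.
  hv.insert_or_assign(10, 0);

  bool hasMode = hv.existsDiscrete(10);  // true
  bool hasX1 = hv.exists(1);             // true, found among the continuous values
  return (hasMode && hasX1) ? 0 : 1;
}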
1 change: 0 additions & 1 deletion gtsam/hybrid/tests/testHybridEstimation.cpp
@@ -333,7 +333,6 @@ TEST(HybridEstimation, Probability) {
for (auto discrete_conditional : *discreteBayesNet) {
bayesNet->add(discrete_conditional);
}
auto discreteConditional = discreteBayesNet->at(0)->asDiscrete();

HybridValues hybrid_values = bayesNet->optimize();

11 changes: 8 additions & 3 deletions gtsam/nonlinear/nonlinear.i
@@ -381,17 +381,22 @@ typedef gtsam::GncOptimizer<gtsam::GncParams<gtsam::LevenbergMarquardtParams>> G
#include <gtsam/nonlinear/LevenbergMarquardtOptimizer.h>
virtual class LevenbergMarquardtOptimizer : gtsam::NonlinearOptimizer {
LevenbergMarquardtOptimizer(const gtsam::NonlinearFactorGraph& graph,
const gtsam::Values& initialValues);
const gtsam::Values& initialValues,
const gtsam::LevenbergMarquardtParams& params =
gtsam::LevenbergMarquardtParams());
LevenbergMarquardtOptimizer(const gtsam::NonlinearFactorGraph& graph,
const gtsam::Values& initialValues,
const gtsam::LevenbergMarquardtParams& params);
const gtsam::Ordering& ordering,
const gtsam::LevenbergMarquardtParams& params =
gtsam::LevenbergMarquardtParams());

double lambda() const;
void print(string s = "") const;
};

#include <gtsam/nonlinear/ISAM2.h>
class ISAM2GaussNewtonParams {
ISAM2GaussNewtonParams();
ISAM2GaussNewtonParams(double _wildfireThreshold = 0.001);

void print(string s = "") const;

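The wrapped signatures mirror the existing C++ constructors; below is a minimal C++ sketch using the variant that takes an explicit elimination Ordering. The toy prior factor graph is made up for illustration:

#include <gtsam/geometry/Point2.h>
#include <gtsam/inference/Ordering.h>
#include <gtsam/linear/NoiseModel.h>
#include <gtsam/nonlinear/LevenbergMarquardtOptimizer.h>
#include <gtsam/nonlinear/NonlinearFactorGraph.h>
#include <gtsam/nonlinear/Values.h>

using namespace gtsam;

int main() {
  // A toy problem: a single prior on a 2D point.
  NonlinearFactorGraph graph;
  auto noise = noiseModel::Isotropic::Sigma(2, 1.0);
  graph.addPrior<Point2>(1, Point2(0.0, 0.0), noise);

  Values initial;
  initial.insert(1, Point2(1.0, 2.0));

  // Explicit elimination ordering, now also exposed through the wrapper.
  Ordering ordering;
  ordering.push_back(1);

  LevenbergMarquardtParams params;
  LevenbergMarquardtOptimizer optimizer(graph, initial, ordering, params);
  Values result = optimizer.optimize();
  result.print("result");
  return 0;
}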