-
Notifications
You must be signed in to change notification settings - Fork 767
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Simpler HybridGaussianFoo constructors #1848
Changes from 22 commits
1f11b54
977112d
b45ba00
ebebf7d
e04f0af
e8089dc
35dd42e
08cf399
69b1313
83fae8e
e18dd3e
2c12e68
bb4c3c9
1a566ea
71d5a6c
3e6227f
d93ebea
14ad127
bc555ae
2b26b3c
7d51e1c
bc25fce
ce45bb6
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -27,57 +27,66 @@ | |
#include <gtsam/linear/GaussianBayesNet.h> | ||
#include <gtsam/linear/GaussianFactorGraph.h> | ||
|
||
#include <cstddef> | ||
|
||
namespace gtsam { | ||
HybridGaussianFactor::FactorValuePairs GetFactorValuePairs( | ||
const HybridGaussianConditional::Conditionals &conditionals) { | ||
auto func = [](const GaussianConditional::shared_ptr &conditional) | ||
-> GaussianFactorValuePair { | ||
double value = 0.0; | ||
// Check if conditional is pruned | ||
if (conditional) { | ||
// Assign log(\sqrt(|2πΣ|)) = -log(1 / sqrt(|2πΣ|)) | ||
value = conditional->negLogConstant(); | ||
/* *******************************************************************************/ | ||
struct HybridGaussianConditional::ConstructorHelper { | ||
std::optional<size_t> nrFrontals; | ||
HybridGaussianFactor::FactorValuePairs pairs; | ||
double minNegLogConstant; | ||
|
||
/// Compute all variables needed for the private constructor below. | ||
ConstructorHelper(const Conditionals &conditionals) | ||
: minNegLogConstant(std::numeric_limits<double>::infinity()) { | ||
auto func = [this](const GaussianConditional::shared_ptr &c) | ||
-> GaussianFactorValuePair { | ||
double value = 0.0; | ||
if (c) { | ||
if (!nrFrontals.has_value()) { | ||
nrFrontals = c->nrFrontals(); | ||
} | ||
value = c->negLogConstant(); | ||
minNegLogConstant = std::min(minNegLogConstant, value); | ||
} | ||
return {std::dynamic_pointer_cast<GaussianFactor>(c), value}; | ||
}; | ||
pairs = HybridGaussianFactor::FactorValuePairs(conditionals, func); | ||
if (!nrFrontals.has_value()) { | ||
throw std::runtime_error( | ||
"HybridGaussianConditional: need at least one frontal variable."); | ||
} | ||
return {std::dynamic_pointer_cast<GaussianFactor>(conditional), value}; | ||
}; | ||
return HybridGaussianFactor::FactorValuePairs(conditionals, func); | ||
} | ||
} | ||
}; | ||
|
||
/* *******************************************************************************/ | ||
HybridGaussianConditional::HybridGaussianConditional( | ||
const DiscreteKeys &discreteParents, | ||
const HybridGaussianConditional::Conditionals &conditionals, | ||
const ConstructorHelper &helper) | ||
: BaseFactor(discreteParents, helper.pairs), | ||
BaseConditional(*helper.nrFrontals), | ||
conditionals_(conditionals), | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. These things are here twice! Once in BaseFactor, and once in conditionals_. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yeah I thought about that too. There was a reason I couldn't get rid of |
||
negLogConstant_(helper.minNegLogConstant) {} | ||
|
||
HybridGaussianConditional::HybridGaussianConditional( | ||
dellaert marked this conversation as resolved.
Show resolved
Hide resolved
|
||
const KeyVector &continuousFrontals, const KeyVector &continuousParents, | ||
const DiscreteKeys &discreteParents, | ||
const HybridGaussianConditional::Conditionals &conditionals) | ||
: BaseFactor(CollectKeys(continuousFrontals, continuousParents), | ||
discreteParents, GetFactorValuePairs(conditionals)), | ||
BaseConditional(continuousFrontals.size()), | ||
conditionals_(conditionals) { | ||
// Calculate negLogConstant_ as the minimum of the negative-log normalizers of | ||
// the conditionals, by visiting the decision tree: | ||
negLogConstant_ = std::numeric_limits<double>::infinity(); | ||
conditionals_.visit( | ||
[this](const GaussianConditional::shared_ptr &conditional) { | ||
if (conditional) { | ||
this->negLogConstant_ = | ||
std::min(this->negLogConstant_, conditional->negLogConstant()); | ||
} | ||
}); | ||
} | ||
: HybridGaussianConditional(discreteParents, conditionals, | ||
ConstructorHelper(conditionals)) {} | ||
|
||
HybridGaussianConditional::HybridGaussianConditional( | ||
dellaert marked this conversation as resolved.
Show resolved
Hide resolved
|
||
const DiscreteKey &discreteParent, | ||
const std::vector<GaussianConditional::shared_ptr> &conditionals) | ||
: HybridGaussianConditional(DiscreteKeys{discreteParent}, | ||
Conditionals({discreteParent}, conditionals)) {} | ||
|
||
/* *******************************************************************************/ | ||
const HybridGaussianConditional::Conditionals & | ||
HybridGaussianConditional::conditionals() const { | ||
return conditionals_; | ||
} | ||
|
||
/* *******************************************************************************/ | ||
HybridGaussianConditional::HybridGaussianConditional( | ||
const KeyVector &continuousFrontals, const KeyVector &continuousParents, | ||
const DiscreteKey &discreteParent, | ||
const std::vector<GaussianConditional::shared_ptr> &conditionals) | ||
: HybridGaussianConditional(continuousFrontals, continuousParents, | ||
DiscreteKeys{discreteParent}, | ||
Conditionals({discreteParent}, conditionals)) {} | ||
|
||
/* *******************************************************************************/ | ||
GaussianFactorGraphTree HybridGaussianConditional::asGaussianFactorGraphTree() | ||
const { | ||
|
@@ -222,8 +231,8 @@ std::shared_ptr<HybridGaussianFactor> HybridGaussianConditional::likelihood( | |
return {likelihood_m, Cgm_Kgcm}; | ||
} | ||
}); | ||
return std::make_shared<HybridGaussianFactor>( | ||
continuousParentKeys, discreteParentKeys, likelihoods); | ||
return std::make_shared<HybridGaussianFactor>(discreteParentKeys, | ||
likelihoods); | ||
} | ||
|
||
/* ************************************************************************* */ | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -21,28 +21,20 @@ | |
#include <gtsam/base/utilities.h> | ||
#include <gtsam/discrete/DecisionTree-inl.h> | ||
#include <gtsam/discrete/DecisionTree.h> | ||
#include <gtsam/hybrid/HybridFactor.h> | ||
#include <gtsam/hybrid/HybridGaussianFactor.h> | ||
#include <gtsam/hybrid/HybridValues.h> | ||
#include <gtsam/linear/GaussianFactor.h> | ||
#include <gtsam/linear/GaussianFactorGraph.h> | ||
|
||
namespace gtsam { | ||
|
||
/** | ||
* @brief Helper function to augment the [A|b] matrices in the factor components | ||
* with the additional scalar values. | ||
* This is done by storing the value in | ||
* the `b` vector as an additional row. | ||
* | ||
* @param factors DecisionTree of GaussianFactors and arbitrary scalars. | ||
* Gaussian factor in factors. | ||
* @return HybridGaussianFactor::Factors | ||
*/ | ||
static HybridGaussianFactor::Factors augment( | ||
const HybridGaussianFactor::FactorValuePairs &factors) { | ||
/* *******************************************************************************/ | ||
HybridGaussianFactor::Factors HybridGaussianFactor::augment( | ||
const FactorValuePairs &factors) { | ||
// Find the minimum value so we can "proselytize" to positive values. | ||
// Done because we can't have sqrt of negative numbers. | ||
HybridGaussianFactor::Factors gaussianFactors; | ||
Factors gaussianFactors; | ||
AlgebraicDecisionTree<Key> valueTree; | ||
std::tie(gaussianFactors, valueTree) = unzip(factors); | ||
|
||
|
@@ -73,22 +65,88 @@ static HybridGaussianFactor::Factors augment( | |
return std::dynamic_pointer_cast<GaussianFactor>( | ||
std::make_shared<JacobianFactor>(gfg)); | ||
}; | ||
return HybridGaussianFactor::Factors(factors, update); | ||
return Factors(factors, update); | ||
} | ||
|
||
/* *******************************************************************************/ | ||
HybridGaussianFactor::HybridGaussianFactor(const KeyVector &continuousKeys, | ||
const DiscreteKeys &discreteKeys, | ||
struct HybridGaussianFactor::ConstructorHelper { | ||
KeyVector continuousKeys; // Continuous keys extracted from factors | ||
DiscreteKeys discreteKeys; // Discrete keys provided to the constructors | ||
FactorValuePairs pairs; // Used only if factorsTree is empty | ||
Factors factorsTree; | ||
|
||
ConstructorHelper(const DiscreteKey &discreteKey, | ||
const std::vector<GaussianFactor::shared_ptr> &factors) | ||
: discreteKeys({discreteKey}) { | ||
// Extract continuous keys from the first non-null factor | ||
for (const auto &factor : factors) { | ||
if (factor && continuousKeys.empty()) { | ||
continuousKeys = factor->keys(); | ||
break; | ||
} | ||
} | ||
|
||
// Build the DecisionTree from the factor vector | ||
factorsTree = Factors(discreteKeys, factors); | ||
} | ||
|
||
ConstructorHelper(const DiscreteKey &discreteKey, | ||
const std::vector<GaussianFactorValuePair> &factorPairs) | ||
: discreteKeys({discreteKey}) { | ||
// Extract continuous keys from the first non-null factor | ||
for (const auto &pair : factorPairs) { | ||
if (pair.first && continuousKeys.empty()) { | ||
continuousKeys = pair.first->keys(); | ||
break; | ||
} | ||
} | ||
|
||
// Build the FactorValuePairs DecisionTree | ||
pairs = FactorValuePairs(discreteKeys, factorPairs); | ||
} | ||
|
||
ConstructorHelper(const DiscreteKeys &discreteKeys, | ||
const FactorValuePairs &factorPairs) | ||
: discreteKeys(discreteKeys) { | ||
// Extract continuous keys from the first non-null factor | ||
factorPairs.visit([&](const GaussianFactorValuePair &pair) { | ||
if (pair.first && continuousKeys.empty()) { | ||
continuousKeys = pair.first->keys(); | ||
} | ||
}); | ||
|
||
// Build the FactorValuePairs DecisionTree | ||
pairs = factorPairs; | ||
} | ||
}; | ||
|
||
/* *******************************************************************************/ | ||
HybridGaussianFactor::HybridGaussianFactor(const ConstructorHelper &helper) | ||
: Base(helper.continuousKeys, helper.discreteKeys), | ||
factors_(helper.factorsTree.empty() ? augment(helper.pairs) | ||
: helper.factorsTree) {} | ||
|
||
HybridGaussianFactor::HybridGaussianFactor( | ||
const DiscreteKey &discreteKey, | ||
const std::vector<GaussianFactor::shared_ptr> &factors) | ||
: HybridGaussianFactor(ConstructorHelper(discreteKey, factors)) {} | ||
|
||
HybridGaussianFactor::HybridGaussianFactor( | ||
const DiscreteKey &discreteKey, | ||
const std::vector<GaussianFactorValuePair> &factorPairs) | ||
: HybridGaussianFactor(ConstructorHelper(discreteKey, factorPairs)) {} | ||
|
||
HybridGaussianFactor::HybridGaussianFactor(const DiscreteKeys &discreteKeys, | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. All three of the above constructors need the divider There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This is more for me when I add it. |
||
const FactorValuePairs &factors) | ||
: Base(continuousKeys, discreteKeys), factors_(augment(factors)) {} | ||
: HybridGaussianFactor(ConstructorHelper(discreteKeys, factors)) {} | ||
|
||
/* *******************************************************************************/ | ||
bool HybridGaussianFactor::equals(const HybridFactor &lf, double tol) const { | ||
const This *e = dynamic_cast<const This *>(&lf); | ||
if (e == nullptr) return false; | ||
|
||
// This will return false if either factors_ is empty or e->factors_ is empty, | ||
// but not if both are empty or both are not empty: | ||
// This will return false if either factors_ is empty or e->factors_ is | ||
// empty, but not if both are empty or both are not empty: | ||
if (factors_.empty() ^ e->factors_.empty()) return false; | ||
|
||
// Check the base and the factors: | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Because we store the value, we should never have to call negLogConstant in this file again!
And that’s why we currently need to keep conditionals :-( at least until the hiding is removed.