Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix minNegLogConstant after pruning #1893

Merged
merged 5 commits into from
Nov 4, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 24 additions & 0 deletions gtsam/hybrid/HybridBayesNet.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -197,6 +197,30 @@ AlgebraicDecisionTree<Key> HybridBayesNet::errorTree(
return result;
}

/* ************************************************************************* */
double HybridBayesNet::negLogConstant(
const std::optional<DiscreteValues> &discrete) const {
double negLogNormConst = 0.0;
// Iterate over each conditional.
for (auto &&conditional : *this) {
if (discrete.has_value()) {
if (auto gm = conditional->asHybrid()) {
negLogNormConst += gm->choose(*discrete)->negLogConstant();
} else if (auto gc = conditional->asGaussian()) {
negLogNormConst += gc->negLogConstant();
} else if (auto dc = conditional->asDiscrete()) {
negLogNormConst += dc->choose(*discrete)->negLogConstant();
} else {
throw std::runtime_error(
"Unknown conditional type when computing negLogConstant");
}
} else {
negLogNormConst += conditional->negLogConstant();
}
}
return negLogNormConst;
}

/* ************************************************************************* */
AlgebraicDecisionTree<Key> HybridBayesNet::discretePosterior(
const VectorValues &continuousValues) const {
Expand Down
10 changes: 10 additions & 0 deletions gtsam/hybrid/HybridBayesNet.h
Original file line number Diff line number Diff line change
Expand Up @@ -237,6 +237,16 @@ class GTSAM_EXPORT HybridBayesNet : public BayesNet<HybridConditional> {

using BayesNet::logProbability; // expose HybridValues version

/**
* @brief Get the negative log of the normalization constant
* corresponding to the joint density represented by this Bayes net.
* Optionally index by `discrete`.
*
* @param discrete Optional DiscreteValues
* @return double
*/
double negLogConstant(const std::optional<DiscreteValues> &discrete) const;

/**
* @brief Compute normalized posterior P(M|X=x) and return as a tree.
*
Expand Down
7 changes: 5 additions & 2 deletions gtsam/hybrid/HybridGaussianConditional.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -322,8 +322,11 @@ HybridGaussianConditional::shared_ptr HybridGaussianConditional::prune(
const GaussianFactorValuePair &pair) -> GaussianFactorValuePair {
if (max->evaluate(choices) == 0.0)
return {nullptr, std::numeric_limits<double>::infinity()};
else
return pair;
else {
// Add negLogConstant_ back so that the minimum negLogConstant in the
// HybridGaussianConditional is set correctly.
return {pair.first, pair.second + negLogConstant_};
}
};

FactorValuePairs prunedConditionals = factors().apply(pruner);
Expand Down
7 changes: 4 additions & 3 deletions gtsam/hybrid/HybridGaussianFactorGraph.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -59,10 +59,11 @@ using OrphanWrapper = BayesTreeOrphanWrapper<HybridBayesTree::Clique>;

/// Result from elimination.
struct Result {
// Gaussian conditional resulting from elimination.
GaussianConditional::shared_ptr conditional;
double negLogK;
GaussianFactor::shared_ptr factor;
double scalar;
double negLogK; // Negative log of the normalization constant K.
GaussianFactor::shared_ptr factor; // Leftover factor 𝜏.
double scalar; // Scalar value associated with factor 𝜏.

bool operator==(const Result &other) const {
return conditional == other.conditional && negLogK == other.negLogK &&
Expand Down
42 changes: 34 additions & 8 deletions gtsam/hybrid/tests/testHybridBayesNet.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -363,10 +363,6 @@ TEST(HybridBayesNet, Pruning) {
AlgebraicDecisionTree<Key> expected(s.modes, leaves);
EXPECT(assert_equal(expected, discretePosterior, 1e-6));

// Prune and get probabilities
auto prunedBayesNet = posterior->prune(2);
auto prunedTree = prunedBayesNet.discretePosterior(delta.continuous());

// Verify logProbability computation and check specific logProbability value
const DiscreteValues discrete_values{{M(0), 1}, {M(1), 1}};
const HybridValues hybridValues{delta.continuous(), discrete_values};
Expand All @@ -381,18 +377,48 @@ TEST(HybridBayesNet, Pruning) {
EXPECT_DOUBLES_EQUAL(logProbability, posterior->logProbability(hybridValues),
1e-9);

double negLogConstant = posterior->negLogConstant(discrete_values);

// The sum of all the mode densities
double normalizer =
AlgebraicDecisionTree<Key>(posterior->errorTree(delta.continuous()),
[](double error) { return exp(-error); })
.sum();

// Check agreement with discrete posterior
// double density = exp(logProbability);
// FAILS: EXPECT_DOUBLES_EQUAL(density, discretePosterior(discrete_values),
// 1e-6);
double density = exp(logProbability + negLogConstant) / normalizer;
EXPECT_DOUBLES_EQUAL(density, discretePosterior(discrete_values), 1e-6);

// Prune and get probabilities
auto prunedBayesNet = posterior->prune(2);
auto prunedTree = prunedBayesNet.discretePosterior(delta.continuous());

// Regression test on pruned logProbability tree
std::vector<double> pruned_leaves = {0.0, 0.50758422, 0.0, 0.49241578};
AlgebraicDecisionTree<Key> expected_pruned(s.modes, pruned_leaves);
EXPECT(assert_equal(expected_pruned, prunedTree, 1e-6));

// Regression
// FAILS: EXPECT_DOUBLES_EQUAL(density, prunedTree(discrete_values), 1e-9);
double pruned_logProbability = 0;
pruned_logProbability +=
prunedBayesNet.at(0)->asDiscrete()->logProbability(hybridValues);
pruned_logProbability +=
prunedBayesNet.at(1)->asHybrid()->logProbability(hybridValues);
pruned_logProbability +=
prunedBayesNet.at(2)->asHybrid()->logProbability(hybridValues);
pruned_logProbability +=
prunedBayesNet.at(3)->asHybrid()->logProbability(hybridValues);

double pruned_negLogConstant = prunedBayesNet.negLogConstant(discrete_values);

// The sum of all the mode densities
double pruned_normalizer =
AlgebraicDecisionTree<Key>(prunedBayesNet.errorTree(delta.continuous()),
[](double error) { return exp(-error); })
.sum();
double pruned_density =
exp(pruned_logProbability + pruned_negLogConstant) / pruned_normalizer;
EXPECT_DOUBLES_EQUAL(pruned_density, prunedTree(discrete_values), 1e-9);
}

/* ****************************************************************************/
Expand Down
8 changes: 8 additions & 0 deletions gtsam/hybrid/tests/testHybridGaussianConditional.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -275,6 +275,11 @@ TEST(HybridGaussianConditional, Prune) {

// Check that the pruned HybridGaussianConditional has 2 conditionals
EXPECT_LONGS_EQUAL(2, pruned->nrComponents());

// Check that the minimum negLogConstant is set correctly
EXPECT_DOUBLES_EQUAL(
hgc.conditionals()({{M(1), 0}, {M(2), 1}})->negLogConstant(),
pruned->negLogConstant(), 1e-9);
}
{
const std::vector<double> potentials{0.2, 0, 0.3, 0, //
Expand All @@ -285,6 +290,9 @@ TEST(HybridGaussianConditional, Prune) {

// Check that the pruned HybridGaussianConditional has 3 conditionals
EXPECT_LONGS_EQUAL(3, pruned->nrComponents());

// Check that the minimum negLogConstant is correct
EXPECT_DOUBLES_EQUAL(hgc.negLogConstant(), pruned->negLogConstant(), 1e-9);
}
}

Expand Down
Loading