From ad31474bf2a53807eb2c50330287d0bff823ca85 Mon Sep 17 00:00:00 2001 From: Michael Ripperger Date: Fri, 29 Sep 2023 18:02:40 -0500 Subject: [PATCH] Added check to ensure seed state parameter is set correctly --- include/reach/reach_study.h | 6 ++++++ src/reach_study.cpp | 18 ++++++++++++++++++ 2 files changed, 24 insertions(+) diff --git a/include/reach/reach_study.h b/include/reach/reach_study.h index c91f879f..8a2a5129 100644 --- a/include/reach/reach_study.h +++ b/include/reach/reach_study.h @@ -85,6 +85,12 @@ class ReachStudy std::tuple getAverageNeighborsCount() const; protected: + /** + * @brief Checks the seed state parameter for validity and sets it to a default value if it does not exist + * @throws Throws an exception if all of the IK solver joint names cannot be found in the seed state parameter + */ + void checkSeedState(); + Parameters params_; ReachDatabase db_; diff --git a/src/reach_study.cpp b/src/reach_study.cpp index 60bc5ed7..08e8d8fc 100644 --- a/src/reach_study.cpp +++ b/src/reach_study.cpp @@ -38,6 +38,7 @@ ReachStudy::ReachStudy(IKSolver::ConstPtr ik_solver, Evaluator::ConstPtr evaluat , target_poses_(target_generator->generate()) , search_tree_(createSearchTree(target_poses_)) { + checkSeedState(); } ReachStudy::ReachStudy(const ReachStudy& rhs) @@ -49,6 +50,23 @@ ReachStudy::ReachStudy(const ReachStudy& rhs) , target_poses_(rhs.target_poses_) , search_tree_(rhs.search_tree_) { + checkSeedState(); +} + +void ReachStudy::checkSeedState() +{ + // Check the optional seed state parameter + const std::vector joint_names = ik_solver_->getJointNames(); + if(params_.seed_state.empty()) + { + logger_->print("Seed state is empty; setting to all-zeros state"); + params_.seed_state = zip(joint_names, std::vector(joint_names.size(), 0.0)); + } + else + { + // Attempt to extract a subset of the seed state for the provided joint names. This function throws an exception if this is not possible + transcribeInputMap(params_.seed_state, joint_names); + } } void ReachStudy::load(const std::string& filename)