Skip to content

Commit

Permalink
Update TaskComposerNodeInfo to allow searching graph
Browse files Browse the repository at this point in the history
  • Loading branch information
Levi-Armstrong committed Aug 15, 2024
1 parent 4895a8e commit 264b3d4
Show file tree
Hide file tree
Showing 16 changed files with 126 additions and 52 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/ubuntu.yml
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ jobs:
matrix:
distro: [focal, jammy, noble]
container:
image: ghcr.io/tesseract-robotics/trajopt:${{ matrix.distro }}-0.23
image: ghcr.io/tesseract-robotics/trajopt:${{ matrix.distro }}-0.24
env:
CCACHE_DIR: "$GITHUB_WORKSPACE/${{ matrix.distro }}/.ccache"
DEBIAN_FRONTEND: noninteractive
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,12 @@ class TaskComposerGraph : public TaskComposerNode
*/
void setTerminalTriggerAbortByIndex(int terminal_index);

/** Get the abort terminal uuid if set */
boost::uuids::uuid getAbortTerminal() const;

/** Get the abort terminal index if set */
int getAbortTerminalIndex() const;

/**
* @brief Check if the current state of the graph is valid
* @todo Replace return type with std::expected when upgraded to use c++23
Expand Down Expand Up @@ -149,6 +155,7 @@ class TaskComposerGraph : public TaskComposerNode

std::map<boost::uuids::uuid, TaskComposerNode::Ptr> nodes_;
std::vector<boost::uuids::uuid> terminals_;
int abort_terminal_{ -1 };
};

} // namespace tesseract_planning
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,6 @@ TESSERACT_COMMON_IGNORE_WARNINGS_POP

#include <tesseract_task_composer/core/task_composer_keys.h>
#include <tesseract_task_composer/core/task_composer_node_ports.h>
#include <tesseract_task_composer/core/task_composer_node_info.h>

namespace YAML
{
Expand All @@ -53,6 +52,7 @@ namespace tesseract_planning
class TaskComposerDataStorage;
class TaskComposerContext;
class TaskComposerExecutor;
class TaskComposerNodeInfo;

enum class TaskComposerNodeType
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,13 +40,12 @@ TESSERACT_COMMON_IGNORE_WARNINGS_POP

#include <tesseract_common/any_poly.h>

#include <tesseract_task_composer/core/task_composer_node.h>
#include <tesseract_task_composer/core/task_composer_keys.h>
#include <tesseract_task_composer/core/task_composer_data_storage.h>

namespace tesseract_planning
{
class TaskComposerNode;

/** Stores information about a node */
class TaskComposerNodeInfo
{
Expand Down Expand Up @@ -79,6 +78,15 @@ class TaskComposerNodeInfo
*/
boost::uuids::uuid parent_uuid{};

/** @brief The node type */
TaskComposerNodeType type{ TaskComposerNodeType::NODE };

/** @brief The task type hash code from std::type_index */
std::size_t type_hash_code{ 0 };

/** @brief The task is conditional or not */
bool conditional{ false };

/** @brief The nodes inbound edges */
std::vector<boost::uuids::uuid> inbound_edges;

Expand All @@ -91,6 +99,12 @@ class TaskComposerNodeInfo
/** @brief The output keys */
TaskComposerKeys output_keys;

/** @brief The graph of pipeline terminals */
std::vector<boost::uuids::uuid> terminals;

/** @brief Indicate if abort terminal was assigned. Only valid for graph and pipelines */
int abort_terminal{ -1 };

/** @brief Value returned from the Task on completion */
int return_value{ -1 };

Expand Down Expand Up @@ -189,6 +203,12 @@ class TaskComposerNodeInfoContainer
/** @brief Merge the contents of another container's info map */
void mergeInfoMap(TaskComposerNodeInfoContainer&& container);

/** @brief Set the root node */
void setRootNode(const boost::uuids::uuid& node_uuid);

/** @brief Get the root node */
boost::uuids::uuid getRootNode() const;

/**
* @brief Called if aborted
* @details This is set if abort is called in input
Expand Down Expand Up @@ -217,6 +237,7 @@ class TaskComposerNodeInfoContainer
void serialize(Archive& ar, const unsigned int version); // NOLINT

mutable std::shared_mutex mutex_;
boost::uuids::uuid root_node_{};
boost::uuids::uuid aborting_node_{};
std::map<boost::uuids::uuid, TaskComposerNodeInfo::UPtr> info_map_;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,15 +24,6 @@ inline void runSerializationTest(const T& input, const std::string& file_name)
EXPECT_FALSE(input != ninput);
}

template <typename T>
inline void runSerializationTestNotEqual(const T& input, const std::string& file_name)
{
const std::string filepath = tesseract_common::getTempPath() + file_name + ".xml";
tesseract_common::Serialization::toArchiveFileXML<T>(input, filepath);
auto ninput = tesseract_common::Serialization::fromArchiveFileXML<T>(filepath);
EXPECT_TRUE(input != ninput);
}

template <typename T>
inline void runSerializationPointerTest(const T& input, const std::string& file_name)
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ std::unique_ptr<TaskComposerNodeInfo> HasDataStorageEntryTask::runImpl(TaskCompo

bool HasDataStorageEntryTask::operator==(const HasDataStorageEntryTask& rhs) const
{
return (TaskComposerNode::operator==(rhs));
return (TaskComposerTask::operator==(rhs));
}
bool HasDataStorageEntryTask::operator!=(const HasDataStorageEntryTask& rhs) const { return !operator==(rhs); }

Expand Down
4 changes: 3 additions & 1 deletion tesseract_task_composer/core/src/task_composer_executor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,9 @@ std::unique_ptr<TaskComposerFuture> TaskComposerExecutor::run(const TaskComposer
std::shared_ptr<TaskComposerDataStorage> data_storage,
bool dotgraph)
{
return run(node, std::make_shared<TaskComposerContext>(node.getName(), std::move(data_storage), dotgraph));
auto context = std::make_shared<TaskComposerContext>(node.getName(), std::move(data_storage), dotgraph);
context->task_infos.setRootNode(node.getUUID());
return run(node, context);
}

bool TaskComposerExecutor::operator==(const TaskComposerExecutor& rhs) const { return (name_ == rhs.name_); }
Expand Down
44 changes: 37 additions & 7 deletions tesseract_task_composer/core/src/task_composer_graph.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -321,17 +321,31 @@ void TaskComposerGraph::setTerminalTriggerAbort(boost::uuids::uuid terminal)
{
if (!terminal.is_nil())
{
auto& n = nodes_.at(terminal);
if (n->getType() == TaskComposerNodeType::TASK)
static_cast<TaskComposerTask&>(*n).setTriggerAbort(true);
else
throw std::runtime_error("Tasks can only trigger abort!");
abort_terminal_ = -1;
for (std::size_t i = 0; i < terminals_.size(); ++i)
{
const boost::uuids::uuid& uuid = terminals_[i];
if (uuid == terminal)
{
abort_terminal_ = static_cast<int>(i);
auto& n = nodes_.at(terminal);
if (n->getType() == TaskComposerNodeType::TASK)
static_cast<TaskComposerTask&>(*n).setTriggerAbort(true);
else
throw std::runtime_error("Tasks can only trigger abort!");

break;
}
}
if (abort_terminal_ < 0)
throw std::runtime_error("Task with uuid: " + boost::uuids::to_string(terminal) + " is not a terminal node");
}
else
{
for (const auto& terminal : terminals_)
abort_terminal_ = -1;
for (const auto& t : terminals_)
{
auto& n = nodes_.at(terminal);
auto& n = nodes_.at(t);
if (n->getType() == TaskComposerNodeType::TASK)
static_cast<TaskComposerTask&>(*n).setTriggerAbort(false);
}
Expand All @@ -342,6 +356,7 @@ void TaskComposerGraph::setTerminalTriggerAbortByIndex(int terminal_index)
{
if (terminal_index >= 0)
{
abort_terminal_ = terminal_index;
auto& n = nodes_.at(terminals_.at(static_cast<std::size_t>(terminal_index)));
if (n->getType() == TaskComposerNodeType::TASK)
static_cast<TaskComposerTask&>(*n).setTriggerAbort(true);
Expand All @@ -350,6 +365,7 @@ void TaskComposerGraph::setTerminalTriggerAbortByIndex(int terminal_index)
}
else
{
abort_terminal_ = -1;
for (const auto& terminal : terminals_)
{
auto& n = nodes_.at(terminal);
Expand All @@ -359,6 +375,16 @@ void TaskComposerGraph::setTerminalTriggerAbortByIndex(int terminal_index)
}
}

boost::uuids::uuid TaskComposerGraph::getAbortTerminal() const
{
if (abort_terminal_ >= 0)
return terminals_.at(static_cast<std::size_t>(abort_terminal_));

return {};
}

int TaskComposerGraph::getAbortTerminalIndex() const { return abort_terminal_; }

std::pair<bool, std::string> TaskComposerGraph::isValid() const
{
int root_node_cnt{ 0 };
Expand Down Expand Up @@ -409,6 +435,7 @@ TaskComposerGraph::dump(std::ostream& os,
<< "\\nUUID: " << uuid_str_ << "\\l";
os << "Inputs:\\l" << input_keys_;
os << "Outputs:\\l" << output_keys_;
os << "Abort Terminal: " << abort_terminal_ << "\\l";
os << "Conditional: " << ((conditional_) ? "True" : "False") << "\\l";
if (getType() == TaskComposerNodeType::PIPELINE || getType() == TaskComposerNodeType::GRAPH)
{
Expand Down Expand Up @@ -436,6 +463,7 @@ TaskComposerGraph::dump(std::ostream& os,
<< "\\l";
os << "Inputs:\\l" << input_keys;
os << "Outputs:\\l" << output_keys;
os << "Abort Terminal: " << static_cast<const TaskComposerGraph&>(*node).abort_terminal_ << "\\l";
os << "Conditional: " << ((node->isConditional()) ? "True" : "False") << "\\l";
if (it != results_map.end())
os << "Time: " << std::fixed << std::setprecision(3) << it->second->elapsed_time << "s\\l";
Expand Down Expand Up @@ -506,6 +534,7 @@ bool TaskComposerGraph::operator==(const TaskComposerGraph& rhs) const
}
}
equal &= (terminals_ == rhs.terminals_);
equal &= (abort_terminal_ == rhs.abort_terminal_);
equal &= TaskComposerNode::operator==(rhs);
return equal;
}
Expand All @@ -519,6 +548,7 @@ void TaskComposerGraph::serialize(Archive& ar, const unsigned int /*version*/)
{
ar& boost::serialization::make_nvp("nodes", nodes_);
ar& boost::serialization::make_nvp("terminals", terminals_);
ar& boost::serialization::make_nvp("abort_terminal", abort_terminal_);
ar& BOOST_SERIALIZATION_BASE_OBJECT_NVP(TaskComposerNode);
}

Expand Down
2 changes: 2 additions & 0 deletions tesseract_task_composer/core/src/task_composer_node.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -469,6 +469,7 @@ bool TaskComposerNode::operator==(const TaskComposerNode& rhs) const
{
bool equal = true;
equal &= name_ == rhs.name_;
equal &= ns_ == rhs.ns_;
equal &= type_ == rhs.type_;
equal &= uuid_ == rhs.uuid_;
equal &= uuid_str_ == rhs.uuid_str_;
Expand All @@ -488,6 +489,7 @@ template <class Archive>
void TaskComposerNode::serialize(Archive& ar, const unsigned int /*version*/)
{
ar& boost::serialization::make_nvp("name", name_);
ar& boost::serialization::make_nvp("ns", ns_);
ar& boost::serialization::make_nvp("type", type_);
ar& boost::serialization::make_nvp("uuid", uuid_);
ar& boost::serialization::make_nvp("uuid_str", uuid_str_);
Expand Down
38 changes: 37 additions & 1 deletion tesseract_task_composer/core/src/task_composer_node_info.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ TESSERACT_COMMON_IGNORE_WARNINGS_PUSH
TESSERACT_COMMON_IGNORE_WARNINGS_POP

#include <tesseract_task_composer/core/task_composer_node_info.h>
#include <tesseract_task_composer/core/task_composer_node.h>
#include <tesseract_task_composer/core/task_composer_graph.h>

namespace tesseract_planning
{
Expand All @@ -48,9 +48,20 @@ TaskComposerNodeInfo::TaskComposerNodeInfo(const TaskComposerNode& node)
, ns(node.ns_)
, uuid(node.uuid_)
, parent_uuid(node.parent_uuid_)
, type(node.type_)
, type_hash_code(std::type_index(typeid(node)).hash_code())
, conditional(node.conditional_)
, inbound_edges(node.inbound_edges_)
, outbound_edges(node.outbound_edges_)
, input_keys(node.input_keys_)
, output_keys(node.output_keys_)
{
if (type == TaskComposerNodeType::GRAPH || type == TaskComposerNodeType::PIPELINE)
{
const auto& graph = static_cast<const TaskComposerGraph&>(node);
terminals = graph.getTerminals();
abort_terminal = graph.getAbortTerminalIndex();
}
}

bool TaskComposerNodeInfo::operator==(const TaskComposerNodeInfo& rhs) const
Expand All @@ -62,6 +73,9 @@ bool TaskComposerNodeInfo::operator==(const TaskComposerNodeInfo& rhs) const
equal &= ns == rhs.ns;
equal &= uuid == rhs.uuid;
equal &= parent_uuid == rhs.parent_uuid;
equal &= type == rhs.type;
equal &= type_hash_code == rhs.type_hash_code;
equal &= conditional == rhs.conditional;
equal &= return_value == rhs.return_value;
equal &= status_code == rhs.status_code;
equal &= status_message == rhs.status_message;
Expand All @@ -71,6 +85,8 @@ bool TaskComposerNodeInfo::operator==(const TaskComposerNodeInfo& rhs) const
equal &= tesseract_common::isIdentical(outbound_edges, rhs.outbound_edges, true);
equal &= input_keys == rhs.input_keys;
equal &= output_keys == rhs.output_keys;
equal &= terminals == rhs.terminals;
equal &= abort_terminal == rhs.abort_terminal;
equal &= color == rhs.color;
equal &= dotgraph == rhs.dotgraph;
equal &= data_storage == rhs.data_storage;
Expand All @@ -89,6 +105,9 @@ void TaskComposerNodeInfo::serialize(Archive& ar, const unsigned int /*version*/
ar& boost::serialization::make_nvp("ns", ns);
ar& boost::serialization::make_nvp("uuid", uuid);
ar& boost::serialization::make_nvp("parent_uuid", parent_uuid);
ar& boost::serialization::make_nvp("type", type);
ar& boost::serialization::make_nvp("type_hash_code", type_hash_code);
ar& boost::serialization::make_nvp("conditional", conditional);
ar& boost::serialization::make_nvp("return_value", return_value);
ar& boost::serialization::make_nvp("status_code", status_code);
ar& boost::serialization::make_nvp("status_message", status_message);
Expand All @@ -99,6 +118,8 @@ void TaskComposerNodeInfo::serialize(Archive& ar, const unsigned int /*version*/
ar& boost::serialization::make_nvp("outbound_edges", outbound_edges);
ar& boost::serialization::make_nvp("input_keys", input_keys);
ar& boost::serialization::make_nvp("output_keys", output_keys);
ar& boost::serialization::make_nvp("terminals", terminals);
ar& boost::serialization::make_nvp("abort_terminal", abort_terminal);
ar& boost::serialization::make_nvp("color", color);
ar& boost::serialization::make_nvp("dotgraph", dotgraph);
ar& boost::serialization::make_nvp("data_storage", data_storage);
Expand Down Expand Up @@ -178,6 +199,18 @@ TaskComposerNodeInfoContainer::find(const std::function<bool(const TaskComposerN
return results;
}

void TaskComposerNodeInfoContainer::setRootNode(const boost::uuids::uuid& node_uuid)
{
std::unique_lock<std::shared_mutex> lock(mutex_);
root_node_ = node_uuid;
}

boost::uuids::uuid TaskComposerNodeInfoContainer::getRootNode() const
{
std::shared_lock<std::shared_mutex> lock(mutex_);
return root_node_;
}

void TaskComposerNodeInfoContainer::setAborted(const boost::uuids::uuid& node_uuid)
{
assert(!node_uuid.is_nil());
Expand Down Expand Up @@ -264,6 +297,8 @@ bool TaskComposerNodeInfoContainer::operator==(const TaskComposerNodeInfoContain
std::scoped_lock lock{ lhs_lock, rhs_lock };

bool equal = true;
equal &= root_node_ == rhs.root_node_;
equal &= aborting_node_ == rhs.aborting_node_;
auto equality = [](const TaskComposerNodeInfo::UPtr& p1, const TaskComposerNodeInfo::UPtr& p2) {
return (p1 && p2 && *p1 == *p2) || (!p1 && !p2);
};
Expand All @@ -285,6 +320,7 @@ template <class Archive>
void TaskComposerNodeInfoContainer::serialize(Archive& ar, const unsigned int /*version*/)
{
std::unique_lock<std::shared_mutex> lock(mutex_);
ar& BOOST_SERIALIZATION_NVP(root_node_);
ar& BOOST_SERIALIZATION_NVP(aborting_node_);
ar& BOOST_SERIALIZATION_NVP(info_map_);
}
Expand Down
2 changes: 0 additions & 2 deletions tesseract_task_composer/core/src/task_composer_pipeline.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -80,8 +80,6 @@ std::unique_ptr<TaskComposerNodeInfo> TaskComposerPipeline::runImpl(TaskComposer
{
timer.stop();
auto info = std::make_unique<TaskComposerNodeInfo>(*this);
info->input_keys = input_keys_;
info->output_keys = output_keys_;
info->return_value = static_cast<int>(i);
info->color = node_info->color;
info->status_code = node_info->status_code;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -157,9 +157,7 @@ class MotionPlannerTask : public TaskComposerTask
return info;
}

std::shared_ptr<const tesseract_environment::Environment> env =
env_poly.template as<std::shared_ptr<const tesseract_environment::Environment>>()->clone();
info->data_storage.setData("environment", env);
auto env = env_poly.template as<std::shared_ptr<const tesseract_environment::Environment>>();

auto input_data_poly = getData(*context.data_storage, INOUT_PROGRAM_PORT);
if (input_data_poly.getType() != std::type_index(typeid(CompositeInstruction)))
Expand Down
Loading

0 comments on commit 264b3d4

Please sign in to comment.