From f3d0558459def45cdbbb43a9e1d632a15ba0dd9d Mon Sep 17 00:00:00 2001 From: lia <167905060+lia-viam@users.noreply.github.com> Date: Thu, 27 Jun 2024 12:20:02 -0400 Subject: [PATCH] RSDK-8031 Pose tracker component (#258) --- src/viam/api/CMakeLists.txt | 8 ++ src/viam/sdk/CMakeLists.txt | 4 + src/viam/sdk/components/pose_tracker.cpp | 16 ++++ src/viam/sdk/components/pose_tracker.hpp | 75 +++++++++++++++++++ .../private/pose_tracker_client.cpp | 49 ++++++++++++ .../private/pose_tracker_client.hpp | 43 +++++++++++ .../private/pose_tracker_server.cpp | 60 +++++++++++++++ .../private/pose_tracker_server.hpp | 44 +++++++++++ src/viam/sdk/registry/registry.cpp | 3 + src/viam/sdk/tests/CMakeLists.txt | 2 + .../sdk/tests/mocks/mock_pose_tracker.cpp | 50 +++++++++++++ .../sdk/tests/mocks/mock_pose_tracker.hpp | 33 ++++++++ src/viam/sdk/tests/test_pose_tracker.cpp | 66 ++++++++++++++++ 13 files changed, 453 insertions(+) create mode 100644 src/viam/sdk/components/pose_tracker.cpp create mode 100644 src/viam/sdk/components/pose_tracker.hpp create mode 100644 src/viam/sdk/components/private/pose_tracker_client.cpp create mode 100644 src/viam/sdk/components/private/pose_tracker_client.hpp create mode 100644 src/viam/sdk/components/private/pose_tracker_server.cpp create mode 100644 src/viam/sdk/components/private/pose_tracker_server.hpp create mode 100644 src/viam/sdk/tests/mocks/mock_pose_tracker.cpp create mode 100644 src/viam/sdk/tests/mocks/mock_pose_tracker.hpp create mode 100644 src/viam/sdk/tests/test_pose_tracker.cpp diff --git a/src/viam/api/CMakeLists.txt b/src/viam/api/CMakeLists.txt index 736ae63b2..1b33312e7 100644 --- a/src/viam/api/CMakeLists.txt +++ b/src/viam/api/CMakeLists.txt @@ -181,6 +181,10 @@ if (VIAMCPPSDK_USE_DYNAMIC_PROTOS) ${PROTO_GEN_DIR}/component/movementsensor/v1/movementsensor.grpc.pb.h ${PROTO_GEN_DIR}/component/movementsensor/v1/movementsensor.pb.cc ${PROTO_GEN_DIR}/component/movementsensor/v1/movementsensor.pb.h + ${PROTO_GEN_DIR}/component/posetracker/v1/pose_tracker.grpc.pb.cc + ${PROTO_GEN_DIR}/component/posetracker/v1/pose_tracker.grpc.pb.h + ${PROTO_GEN_DIR}/component/posetracker/v1/pose_tracker.pb.cc + ${PROTO_GEN_DIR}/component/posetracker/v1/pose_tracker.pb.h ${PROTO_GEN_DIR}/component/powersensor/v1/powersensor.grpc.pb.cc ${PROTO_GEN_DIR}/component/powersensor/v1/powersensor.grpc.pb.h ${PROTO_GEN_DIR}/component/powersensor/v1/powersensor.pb.cc @@ -302,6 +306,8 @@ target_sources(viamapi ${PROTO_GEN_DIR}/component/motor/v1/motor.pb.cc ${PROTO_GEN_DIR}/component/movementsensor/v1/movementsensor.grpc.pb.cc ${PROTO_GEN_DIR}/component/movementsensor/v1/movementsensor.pb.cc + ${PROTO_GEN_DIR}/component/posetracker/v1/pose_tracker.grpc.pb.cc + ${PROTO_GEN_DIR}/component/posetracker/v1/pose_tracker.pb.cc ${PROTO_GEN_DIR}/component/powersensor/v1/powersensor.grpc.pb.cc ${PROTO_GEN_DIR}/component/powersensor/v1/powersensor.pb.cc ${PROTO_GEN_DIR}/component/sensor/v1/sensor.grpc.pb.cc @@ -356,6 +362,8 @@ target_sources(viamapi ${PROTO_GEN_DIR}/../../viam/api/component/motor/v1/motor.pb.h ${PROTO_GEN_DIR}/../../viam/api/component/movementsensor/v1/movementsensor.grpc.pb.h ${PROTO_GEN_DIR}/../../viam/api/component/movementsensor/v1/movementsensor.pb.h + ${PROTO_GEN_DIR}/../../viam/api/component/posetracker/v1/pose_tracker.grpc.pb.h + ${PROTO_GEN_DIR}/../../viam/api/component/posetracker/v1/pose_tracker.pb.h ${PROTO_GEN_DIR}/../../viam/api/component/powersensor/v1/powersensor.grpc.pb.h ${PROTO_GEN_DIR}/../../viam/api/component/powersensor/v1/powersensor.pb.h ${PROTO_GEN_DIR}/../../viam/api/component/sensor/v1/sensor.grpc.pb.h diff --git a/src/viam/sdk/CMakeLists.txt b/src/viam/sdk/CMakeLists.txt index f6ef08215..5136d59e5 100644 --- a/src/viam/sdk/CMakeLists.txt +++ b/src/viam/sdk/CMakeLists.txt @@ -55,6 +55,7 @@ target_sources(viamsdk components/gripper.cpp components/motor.cpp components/movement_sensor.cpp + components/pose_tracker.cpp components/power_sensor.cpp components/private/arm_client.cpp components/private/arm_server.cpp @@ -76,6 +77,8 @@ target_sources(viamsdk components/private/motor_server.cpp components/private/movement_sensor_client.cpp components/private/movement_sensor_server.cpp + components/private/pose_tracker_client.cpp + components/private/pose_tracker_server.cpp components/private/power_sensor_client.cpp components/private/power_sensor_server.cpp components/private/sensor_client.cpp @@ -138,6 +141,7 @@ target_sources(viamsdk ../../viam/sdk/components/gripper.hpp ../../viam/sdk/components/motor.hpp ../../viam/sdk/components/movement_sensor.hpp + ../../viam/sdk/components/pose_tracker.hpp ../../viam/sdk/components/power_sensor.hpp ../../viam/sdk/components/sensor.hpp ../../viam/sdk/components/servo.hpp diff --git a/src/viam/sdk/components/pose_tracker.cpp b/src/viam/sdk/components/pose_tracker.cpp new file mode 100644 index 000000000..0de662b44 --- /dev/null +++ b/src/viam/sdk/components/pose_tracker.cpp @@ -0,0 +1,16 @@ +#include + +#include + +namespace viam { +namespace sdk { +API PoseTracker::api() const { + return API::get(); +} + +API API::traits::api() { + return {kRDK, kComponent, "pose_tracker"}; +} + +} // namespace sdk +} // namespace viam diff --git a/src/viam/sdk/components/pose_tracker.hpp b/src/viam/sdk/components/pose_tracker.hpp new file mode 100644 index 000000000..563853081 --- /dev/null +++ b/src/viam/sdk/components/pose_tracker.hpp @@ -0,0 +1,75 @@ +/// @file components/pose_tracker.hpp +/// +/// @brief Defines a `PoseTracker` component +#pragma once + +#include +#include +#include + +#include +#include + +#include +#include +#include + +namespace viam { +namespace sdk { + +/// @defgroup PoseTracker Classes related to the PoseTracker component. + +/// @class PoseTracker pose_tracker.hpp "components/pose_tracker.hpp" +/// @brief A `PoseTracker` represents a physical pose or motion tracking device. +/// +/// This acts as an abstract base class for any drivers representing specific +/// pose tracker implementations +class PoseTracker : public Component { + public: + using pose_map = std::unordered_map; + + API api() const override; + + /// @brief Get the poses of each body tracked by the pose tracker. + /// @param tracker_name The name of the pose tracker. + /// @param body_names Names of bodies whose poses are being requested. If the vector is empty + /// then all available poses are returned. + /// @return A mapping of each body to its pose. + inline pose_map get_poses(const std::vector& body_names) { + return get_poses(body_names, {}); + } + + /// @brief Get the poses of each body tracked by the pose tracker. + /// @param tracker_name The name of the pose tracker. + /// @param body_names Names of bodies whose poses are being requested. If the vector is empty + /// then all available poses are returned. + /// @param extra Any additional arguments to the method. + /// @return A mapping of each body to its pose. + virtual pose_map get_poses(const std::vector& body_names, + const AttributeMap& extra) = 0; + + /// @brief Send/receive arbitrary commands to the resource. + /// @param Command the command to execute. + /// @return The result of the executed command. + virtual AttributeMap do_command(const AttributeMap& command) = 0; + + /// @brief Returns `GeometryConfig`s associated with the calling pose tracker + inline std::vector get_geometries() { + return get_geometries({}); + } + + /// @brief Returns `GeometryConfig`s associated with the calling pose tracker + /// @param extra Any additional arguments to the method + virtual std::vector get_geometries(const AttributeMap& extra) = 0; + + protected: + using Component::Component; +}; + +template <> +struct API::traits { + static API api(); +}; + +} // namespace sdk +} // namespace viam diff --git a/src/viam/sdk/components/private/pose_tracker_client.cpp b/src/viam/sdk/components/private/pose_tracker_client.cpp new file mode 100644 index 000000000..ccbea7edf --- /dev/null +++ b/src/viam/sdk/components/private/pose_tracker_client.cpp @@ -0,0 +1,49 @@ +#include + +#include +#include +#include + +#include + +namespace viam { +namespace sdk { +namespace impl { + +PoseTrackerClient::PoseTrackerClient(std::string name, std::shared_ptr channel) + : PoseTracker(std::move(name)), + stub_(viam::component::posetracker::v1::PoseTrackerService::NewStub(channel)), + channel_(std::move(channel)) {} + +PoseTracker::pose_map PoseTrackerClient::get_poses(const std::vector& body_names, + const AttributeMap&) { + return make_client_helper(this, *stub_, &StubType::GetPoses) + .with([&](viam::component::posetracker::v1::GetPosesRequest& request) { + *request.mutable_body_names() = {body_names.begin(), body_names.end()}; + }) + .invoke([](const viam::component::posetracker::v1::GetPosesResponse& response) { + PoseTracker::pose_map result; + + for (const auto& pair : response.body_poses()) { + result.emplace(pair.first, pose_in_frame::from_proto(pair.second)); + } + + return result; + }); +} + +std::vector PoseTrackerClient::get_geometries(const AttributeMap& extra) { + return make_client_helper(this, *stub_, &StubType::GetGeometries) + .with(extra) + .invoke([](auto& response) { return GeometryConfig::from_proto(response); }); +} + +AttributeMap PoseTrackerClient::do_command(const AttributeMap& command) { + return make_client_helper(this, *stub_, &StubType::DoCommand) + .with([&](auto& request) { *request.mutable_command() = map_to_struct(command); }) + .invoke([](auto& response) { return struct_to_map(response.result()); }); +} + +} // namespace impl +} // namespace sdk +} // namespace viam diff --git a/src/viam/sdk/components/private/pose_tracker_client.hpp b/src/viam/sdk/components/private/pose_tracker_client.hpp new file mode 100644 index 000000000..f21854d71 --- /dev/null +++ b/src/viam/sdk/components/private/pose_tracker_client.hpp @@ -0,0 +1,43 @@ +/// @file components/private/pose_tracker_client.hpp +/// +/// @brief Implements a gRPC client for the `PoseTracker` component +#pragma once + +#include + +#include + +#include + +namespace viam { +namespace sdk { +namespace impl { + +/// @class PoseTrackerClient +/// @brief gRPC client implementation of a `PoseTracker` component. +/// @ingroup PoseTracker +class PoseTrackerClient : public PoseTracker { + public: + using interface_type = PoseTracker; + + PoseTrackerClient(std::string name, std::shared_ptr channel); + + PoseTracker::pose_map get_poses(const std::vector& body_names, + const AttributeMap& extra) override; + + AttributeMap do_command(const AttributeMap& command) override; + + std::vector get_geometries(const AttributeMap& extra) override; + + using PoseTracker::get_geometries; + using PoseTracker::get_poses; + + private: + using StubType = viam::component::posetracker::v1::PoseTrackerService::StubInterface; + std::unique_ptr stub_; + std::shared_ptr channel_; +}; + +} // namespace impl +} // namespace sdk +} // namespace viam diff --git a/src/viam/sdk/components/private/pose_tracker_server.cpp b/src/viam/sdk/components/private/pose_tracker_server.cpp new file mode 100644 index 000000000..9e0093ea4 --- /dev/null +++ b/src/viam/sdk/components/private/pose_tracker_server.cpp @@ -0,0 +1,60 @@ +#include + +#include + +#include +#include +#include +#include +#include + +namespace viam { +namespace sdk { +namespace impl { + +PoseTrackerServer::PoseTrackerServer(std::shared_ptr manager) + : ResourceServer(std::move(manager)) {} + +::grpc::Status PoseTrackerServer::GetPoses( + ::grpc::ServerContext*, + const ::viam::component::posetracker::v1::GetPosesRequest* request, + ::viam::component::posetracker::v1::GetPosesResponse* response) noexcept { + return make_service_helper( + "PoseTrackerServer::GetPoses", this, request)([&](auto& helper, auto& pose_tracker) { + const PoseTracker::pose_map result = pose_tracker->get_poses( + {request->body_names().begin(), request->body_names().end()}, helper.getExtra()); + + for (const auto& pair : result) { + response->mutable_body_poses()->insert({pair.first, pair.second.to_proto()}); + } + }); +} + +::grpc::Status PoseTrackerServer::DoCommand( + grpc::ServerContext*, + const viam::common::v1::DoCommandRequest* request, + viam::common::v1::DoCommandResponse* response) noexcept { + return make_service_helper( + "PoseTrackerServer::DoCommand", this, request)([&](auto&, auto& pose_tracker) { + const AttributeMap result = pose_tracker->do_command(struct_to_map(request->command())); + *response->mutable_result() = map_to_struct(result); + }); +} + +::grpc::Status PoseTrackerServer::GetGeometries( + ::grpc::ServerContext*, + const ::viam::common::v1::GetGeometriesRequest* request, + ::viam::common::v1::GetGeometriesResponse* response) noexcept { + return make_service_helper( + "PoseTrackerServer::GetGeometries", this, request)([&](auto& helper, auto& pose_tracker) { + const std::vector geometries = + pose_tracker->get_geometries(helper.getExtra()); + for (const auto& geometry : geometries) { + *response->mutable_geometries()->Add() = geometry.to_proto(); + } + }); +} + +} // namespace impl +} // namespace sdk +} // namespace viam diff --git a/src/viam/sdk/components/private/pose_tracker_server.hpp b/src/viam/sdk/components/private/pose_tracker_server.hpp new file mode 100644 index 000000000..3728c1ac5 --- /dev/null +++ b/src/viam/sdk/components/private/pose_tracker_server.hpp @@ -0,0 +1,44 @@ +/// @file components/private/pose_tracker_server.hpp +/// +/// @brief Implements a gRPC server for the `PoseTracker` component. +#pragma once + +#include +#include +#include + +#include +#include +#include + +namespace viam { +namespace sdk { +namespace impl { + +/// @class PoseTrackerServer +/// @brief gRPC server implementation of a `PoseTracker` component. +/// @ingroup PoseTracker +class PoseTrackerServer : public ResourceServer, + public viam::component::posetracker::v1::PoseTrackerService::Service { + public: + using interface_type = PoseTracker; + using service_type = component::posetracker::v1::PoseTrackerService; + explicit PoseTrackerServer(std::shared_ptr manager); + + ::grpc::Status GetPoses( + ::grpc::ServerContext* context, + const ::viam::component::posetracker::v1::GetPosesRequest* request, + ::viam::component::posetracker::v1::GetPosesResponse* response) noexcept override; + + ::grpc::Status DoCommand(::grpc::ServerContext* context, + const ::viam::common::v1::DoCommandRequest* request, + ::viam::common::v1::DoCommandResponse* response) noexcept override; + + ::grpc::Status GetGeometries( + ::grpc::ServerContext* context, + const ::viam::common::v1::GetGeometriesRequest* request, + ::viam::common::v1::GetGeometriesResponse* response) noexcept override; +}; +} // namespace impl +} // namespace sdk +} // namespace viam diff --git a/src/viam/sdk/registry/registry.cpp b/src/viam/sdk/registry/registry.cpp index 7cb38aca3..4e0819722 100644 --- a/src/viam/sdk/registry/registry.cpp +++ b/src/viam/sdk/registry/registry.cpp @@ -32,6 +32,8 @@ #include #include #include +#include +#include #include #include #include @@ -179,6 +181,7 @@ void register_resources() { Registry::register_resource(); Registry::register_resource(); Registry::register_resource(); + Registry::register_resource(); Registry::register_resource(); Registry::register_resource(); Registry::register_resource(); diff --git a/src/viam/sdk/tests/CMakeLists.txt b/src/viam/sdk/tests/CMakeLists.txt index 24122e537..b07f67243 100644 --- a/src/viam/sdk/tests/CMakeLists.txt +++ b/src/viam/sdk/tests/CMakeLists.txt @@ -31,6 +31,7 @@ target_sources(viamsdk_test mocks/mock_motor.cpp mocks/mock_motion.cpp mocks/mock_movement_sensor.cpp + mocks/mock_pose_tracker.cpp mocks/mock_power_sensor.cpp mocks/mock_sensor.cpp mocks/mock_servo.cpp @@ -55,6 +56,7 @@ viamcppsdk_add_boost_test(test_mlmodel.cpp) viamcppsdk_add_boost_test(test_motor.cpp) viamcppsdk_add_boost_test(test_motion.cpp) viamcppsdk_add_boost_test(test_movement_sensor.cpp) +viamcppsdk_add_boost_test(test_pose_tracker.cpp) viamcppsdk_add_boost_test(test_power_sensor.cpp) viamcppsdk_add_boost_test(test_resource.cpp) viamcppsdk_add_boost_test(test_sensor.cpp) diff --git a/src/viam/sdk/tests/mocks/mock_pose_tracker.cpp b/src/viam/sdk/tests/mocks/mock_pose_tracker.cpp new file mode 100644 index 000000000..560c60f44 --- /dev/null +++ b/src/viam/sdk/tests/mocks/mock_pose_tracker.cpp @@ -0,0 +1,50 @@ +#include + +#include + +namespace viam { +namespace sdktests { +namespace pose_tracker { + +using namespace viam::sdk; + +PoseTracker::pose_map fake_poses() { + return { + {body1, pose_in_frame("", {{0, 0, 0}, {0, 0, 0}, 0})}, + {body2, pose_in_frame("", {{1, 2, 3}, {4, 5, 6}, 7})}, + }; +} + +std::shared_ptr MockPoseTracker::get_mock_pose_tracker() { + return std::make_shared("mock_pose_tracker"); +} + +PoseTracker::pose_map MockPoseTracker::get_poses(const std::vector& body_names, + const sdk::AttributeMap&) { + auto full_map = fake_poses(); + + if (body_names.empty()) + return full_map; + + PoseTracker::pose_map result; + for (const auto& pair : full_map) { + if (std::find(body_names.begin(), body_names.end(), pair.first) != body_names.end()) { + result.insert(pair); + } + } + + return result; +} + +AttributeMap MockPoseTracker::do_command(const AttributeMap& command) { + this->peek_do_command_command = command; + return command; +} + +std::vector MockPoseTracker::get_geometries(const sdk::AttributeMap&) { + return fake_geometries(); +} + +} // namespace pose_tracker +} // namespace sdktests +} // namespace viam diff --git a/src/viam/sdk/tests/mocks/mock_pose_tracker.hpp b/src/viam/sdk/tests/mocks/mock_pose_tracker.hpp new file mode 100644 index 000000000..903f31b7f --- /dev/null +++ b/src/viam/sdk/tests/mocks/mock_pose_tracker.hpp @@ -0,0 +1,33 @@ +#pragma once + +#include + +namespace viam { +namespace sdktests { +namespace pose_tracker { + +// body names for the fake poses +constexpr const char* body1 = "b1"; +constexpr const char* body2 = "b2"; + +sdk::PoseTracker::pose_map fake_poses(); + +class MockPoseTracker : public sdk::PoseTracker { + public: + static std::shared_ptr get_mock_pose_tracker(); + + MockPoseTracker(std::string name) : sdk::PoseTracker(std::move(name)) {} + + sdk::PoseTracker::pose_map get_poses(const std::vector& body_names, + const sdk::AttributeMap& extra) override; + + sdk::AttributeMap do_command(const sdk::AttributeMap& command) override; + + std::vector get_geometries(const sdk::AttributeMap& extra) override; + + sdk::AttributeMap peek_do_command_command; +}; + +} // namespace pose_tracker +} // namespace sdktests +} // namespace viam diff --git a/src/viam/sdk/tests/test_pose_tracker.cpp b/src/viam/sdk/tests/test_pose_tracker.cpp new file mode 100644 index 000000000..cfaef7aec --- /dev/null +++ b/src/viam/sdk/tests/test_pose_tracker.cpp @@ -0,0 +1,66 @@ +#define BOOST_TEST_MODULE test module test_pose_tracker + +#include + +#include +#include +#include + +BOOST_TEST_DONT_PRINT_LOG_VALUE(std::vector) +BOOST_TEST_DONT_PRINT_LOG_VALUE(viam::sdk::PoseTracker::pose_map) + +namespace viam { +namespace sdktests { + +using namespace pose_tracker; +using namespace viam::sdk; + +BOOST_AUTO_TEST_SUITE(test_pose_tracker) + +BOOST_AUTO_TEST_CASE(mock_get_api) { + const MockPoseTracker pose_tracker("mock_pose_tracker"); + auto api = pose_tracker.api(); + auto static_api = API::get(); + + BOOST_CHECK_EQUAL(api, static_api); + BOOST_CHECK_EQUAL(static_api.resource_subtype(), "pose_tracker"); +} + +BOOST_AUTO_TEST_CASE(mock_get_poses) { + std::shared_ptr mock = MockPoseTracker::get_mock_pose_tracker(); + client_to_mock_pipeline(mock, [](PoseTracker& client) { + const PoseTracker::pose_map& fakes = fake_poses(); + BOOST_CHECK_EQUAL(fakes, client.get_poses({})); + BOOST_CHECK_EQUAL(fakes, client.get_poses({body1, body2})); + + const PoseTracker::pose_map& single = client.get_poses({body1}); + BOOST_CHECK_EQUAL(single.size(), 1); + BOOST_CHECK_EQUAL(single.at(body1), fakes.at(body1)); + }); +} + +BOOST_AUTO_TEST_CASE(test_do_command) { + std::shared_ptr mock = MockPoseTracker::get_mock_pose_tracker(); + client_to_mock_pipeline(mock, [](PoseTracker& client) { + AttributeMap expected = fake_map(); + + AttributeMap command = fake_map(); + AttributeMap result_map = client.do_command(command); + + ProtoType expected_pt = *(expected->at(std::string("test"))); + ProtoType result_pt = *(result_map->at(std::string("test"))); + BOOST_CHECK(result_pt == expected_pt); + }); +} + +BOOST_AUTO_TEST_CASE(test_get_geometries) { + std::shared_ptr mock = MockPoseTracker::get_mock_pose_tracker(); + client_to_mock_pipeline(mock, [](PoseTracker& client) { + const auto& geometries = client.get_geometries(); + BOOST_CHECK_EQUAL(geometries, fake_geometries()); + }); +} + +BOOST_AUTO_TEST_SUITE_END() +} // namespace sdktests +} // namespace viam