diff --git a/include/tensorwrapper/buffer/buffer_base.hpp b/include/tensorwrapper/buffer/buffer_base.hpp index ac61712c..9503e94d 100644 --- a/include/tensorwrapper/buffer/buffer_base.hpp +++ b/include/tensorwrapper/buffer/buffer_base.hpp @@ -16,7 +16,9 @@ #pragma once #include +#include #include + namespace tensorwrapper::buffer { /** @brief Common base class for all buffer objects. @@ -35,6 +37,9 @@ class BufferBase : public detail_::PolymorphicBase { /// Type all buffers inherit from using buffer_base_type = typename my_base_type::base_type; + /// Type of a mutable reference to a buffer_base_type object + using buffer_base_reference = typename my_base_type::base_reference; + /// Type of a read-only reference to a buffer_base_type object using const_buffer_base_reference = typename my_base_type::const_base_reference; @@ -54,6 +59,18 @@ class BufferBase : public detail_::PolymorphicBase { /// Type of a pointer to the layout using layout_pointer = typename layout_type::layout_pointer; + /// Type of labels for making a labeled buffer + using label_type = std::string; + + /// Type of a labeled buffer + using labeled_buffer_type = dsl::Labeled; + + /// Type of a labeled read-only buffer (n.b. labels are mutable) + using labeled_const_buffer_type = dsl::Labeled; + + /// Type of a read-only reference to a labeled_buffer_type object + using const_labeled_buffer_reference = const labeled_buffer_type&; + // ------------------------------------------------------------------------- // -- Accessors // ------------------------------------------------------------------------- @@ -85,10 +102,53 @@ class BufferBase : public detail_::PolymorphicBase { return *m_layout_; } + // ------------------------------------------------------------------------- + // -- BLAS Operations + // ------------------------------------------------------------------------- + + buffer_base_reference addition_assignment( + label_type this_labels, const_labeled_buffer_reference rhs) { + return addition_assignment_(std::move(this_labels), rhs); + } + + buffer_base_pointer addition(label_type this_labels, + const_labeled_buffer_reference rhs) const { + auto pthis = clone(); + pthis->addition_assignment(std::move(this_labels), rhs); + return pthis; + } + // ------------------------------------------------------------------------- // -- Utility methods // ------------------------------------------------------------------------- + /** @brief Associates labels with the modes of *this. + * + * This method is used to create a labeled buffer object by pairing *this + * with the provided labels. The resulting object is capable of being + * composed via the DSL. + * + * @param[in] labels The indices to associate with the modes of *this. + * + * @return A DSL term pairing *this with @p labels. + * + * @throw None No throw guarantee. + */ + labeled_buffer_type operator()(label_type labels); + + /** @brief Associates labels with the modes of *this. + * + * This method is the same as the non-const version except that the result + * contains a read-only reference to *this. + * + * @param[in] labels The labels to associate with *this. + * + * @return A DSL term pairing *this with @p labels. + * + * @throw None No throw guarantee. + */ + labeled_const_buffer_type operator()(label_type labels) const; + /** @brief Is *this value equal to @p rhs? * * Two BufferBase objects are value equal if the layouts they contain are @@ -186,6 +246,12 @@ class BufferBase : public detail_::PolymorphicBase { return *this; } + /// Derived class should overwrite to implement addition_assignment + virtual buffer_base_reference addition_assignment_( + label_type this_labels, const_labeled_buffer_reference rhs) { + throw std::runtime_error("Addition assignment NYI"); + } + private: /// Throws std::runtime_error when there is no layout void assert_layout_() const { diff --git a/include/tensorwrapper/buffer/eigen.hpp b/include/tensorwrapper/buffer/eigen.hpp index f4087d4f..faad3937 100644 --- a/include/tensorwrapper/buffer/eigen.hpp +++ b/include/tensorwrapper/buffer/eigen.hpp @@ -16,7 +16,7 @@ #pragma once #include -#include +#include namespace tensorwrapper::buffer { @@ -27,7 +27,7 @@ namespace tensorwrapper::buffer { * */ template -class Eigen : public EigenAny { +class Eigen : public Replicated { private: /// Type of *this using my_type = Eigen; @@ -39,7 +39,9 @@ class Eigen : public EigenAny { /// Pull in base class's types using typename my_base_type::buffer_base_pointer; using typename my_base_type::const_buffer_base_reference; + using typename my_base_type::const_labeled_buffer_reference; using typename my_base_type::const_layout_reference; + using typename my_base_type::label_type; /// Type of a rank @p Rank tensor using floats of type @p FloatType using data_type = eigen::data_type; @@ -180,29 +182,31 @@ class Eigen : public EigenAny { return my_base_type::are_equal_impl_(rhs); } - EigenAny& add_(const EigenAny& other) override { - auto& down_other = downcast_(other); - m_tensor_ += down_other.value(); - return *this; - } - - template - static Eigen& downcast_(EigenAny& other) { - auto pother = dynamic_cast*>(&other); - if(pother == nullptr) throw std::runtime_error("Not convertible"); - return *pother; - } - - template - static const Eigen& downcast_(const EigenAny& other) { - auto pother = dynamic_cast*>(&other); - if(pother == nullptr) throw std::runtime_error("Not convertible"); - return *pother; - } + /// Implements addition_assignment by rebinding rhs + buffer_base_reference addition_assignment_( + label_type this_labels, const_labeled_buffer_reference rhs) override; private: /// The actual Eigen tensor data_type m_tensor_; }; +#define DECLARE_EIGEN_BUFFER(RANK) \ + extern template class Eigen; \ + extern template class Eigen + +DECLARE_EIGEN_BUFFER(0); +DECLARE_EIGEN_BUFFER(1); +DECLARE_EIGEN_BUFFER(2); +DECLARE_EIGEN_BUFFER(3); +DECLARE_EIGEN_BUFFER(4); +DECLARE_EIGEN_BUFFER(5); +DECLARE_EIGEN_BUFFER(6); +DECLARE_EIGEN_BUFFER(7); +DECLARE_EIGEN_BUFFER(8); +DECLARE_EIGEN_BUFFER(9); +DECLARE_EIGEN_BUFFER(10); + +#undef DECLARE_EIGEN_BUFFER + } // namespace tensorwrapper::buffer diff --git a/src/tensorwrapper/buffer/buffer_base.cpp b/src/tensorwrapper/buffer/buffer_base.cpp new file mode 100644 index 00000000..23360b7f --- /dev/null +++ b/src/tensorwrapper/buffer/buffer_base.cpp @@ -0,0 +1,15 @@ +#include + +namespace tensorwrapper::buffer { + +typename BufferBase::labeled_buffer_type BufferBase::operator()( + label_type labels) { + return labeled_buffer_type(*this, std::move(labels)); +} + +typename BufferBase::labeled_const_buffer_type BufferBase::operator()( + label_type labels) const { + return labeled_const_buffer_type(*this, std::move(labels)); +} + +} // namespace tensorwrapper::buffer \ No newline at end of file diff --git a/src/tensorwrapper/buffer/eigen.cpp b/src/tensorwrapper/buffer/eigen.cpp new file mode 100644 index 00000000..902e22ba --- /dev/null +++ b/src/tensorwrapper/buffer/eigen.cpp @@ -0,0 +1,49 @@ +#include +#include + +namespace tensorwrapper::buffer { + +#define TPARAMS template +#define EIGEN Eigen + +TPARAMS +typename EIGEN::buffer_base_reference EIGEN::addition_assignment_( + label_type this_labels, const_labeled_buffer_reference rhs) { + using allocator_type = allocator::Eigen; + + if(this_labels != rhs.rhs()) + throw std::runtime_error("Labels must match (for now)!"); + + if(layout() != rhs.lhs().layout()) + throw std::runtime_error("Layouts must be the same (for now)"); + + const auto& rhs_downcasted = allocator_type::rebind(rhs.lhs()); + + m_tensor_ += rhs_downcasted.value(); + + // TODO layouts + return *this; +} + +#undef EIGEN +#undef TPARAMS + +#define DEFINE_EIGEN_BUFFER(RANK) \ + template class Eigen; \ + template class Eigen + +DEFINE_EIGEN_BUFFER(0); +DEFINE_EIGEN_BUFFER(1); +DEFINE_EIGEN_BUFFER(2); +DEFINE_EIGEN_BUFFER(3); +DEFINE_EIGEN_BUFFER(4); +DEFINE_EIGEN_BUFFER(5); +DEFINE_EIGEN_BUFFER(6); +DEFINE_EIGEN_BUFFER(7); +DEFINE_EIGEN_BUFFER(8); +DEFINE_EIGEN_BUFFER(9); +DEFINE_EIGEN_BUFFER(10); + +#undef DEFINE_EIGEN_BUFFER + +} // namespace tensorwrapper::buffer \ No newline at end of file diff --git a/src/tensorwrapper/dsl/executor/eigen.hpp b/src/tensorwrapper/dsl/executor/eigen.hpp index bc909c58..79d156c2 100644 --- a/src/tensorwrapper/dsl/executor/eigen.hpp +++ b/src/tensorwrapper/dsl/executor/eigen.hpp @@ -15,7 +15,6 @@ */ #pragma once -#include "detail_/eigen_dispatcher.hpp" #include namespace tensorwrapper::dsl::executor { diff --git a/tests/cxx/unit_tests/tensorwrapper/allocator/eigen.cpp b/tests/cxx/unit_tests/tensorwrapper/allocator/eigen.cpp index 8ac06fdf..ef83631c 100644 --- a/tests/cxx/unit_tests/tensorwrapper/allocator/eigen.cpp +++ b/tests/cxx/unit_tests/tensorwrapper/allocator/eigen.cpp @@ -17,6 +17,7 @@ #include "../helpers.hpp" #include #include +#include #include using namespace tensorwrapper; @@ -124,6 +125,29 @@ TEMPLATE_TEST_CASE("EigenAllocator", "", float, double) { REQUIRE_THROWS_AS(scalar_alloc.allocate(vector_layout), except_t); } + SECTION("can_rebind") { + REQUIRE(scalar_alloc.can_rebind(scalar_corr)); + REQUIRE_FALSE(scalar_alloc.can_rebind(vector_corr)); + } + + SECTION("rebind(non-const)") { + using type = typename scalar_alloc_type::buffer_base_reference; + type scalar_base = scalar_corr; + auto& eigen_buffer = scalar_alloc.rebind(scalar_base); + REQUIRE(&eigen_buffer == &scalar_corr); + REQUIRE_THROWS_AS(scalar_alloc.rebind(vector_corr), std::runtime_error); + } + + SECTION("rebind(const)") { + using type = typename scalar_alloc_type::const_buffer_base_reference; + type scalar_base = scalar_corr; + auto& eigen_buffer = scalar_alloc.rebind(scalar_base); + REQUIRE(&eigen_buffer == &scalar_corr); + + type vector_base = vector_corr; + REQUIRE_THROWS_AS(scalar_alloc.rebind(vector_base), std::runtime_error); + } + SECTION("operator==") { REQUIRE(scalar_alloc == scalar_alloc_type(rv)); REQUIRE_FALSE(scalar_alloc == vector_alloc); diff --git a/tests/cxx/unit_tests/tensorwrapper/buffer/buffer_base.cpp b/tests/cxx/unit_tests/tensorwrapper/buffer/buffer_base.cpp index f17e1aa6..fb6e379b 100644 --- a/tests/cxx/unit_tests/tensorwrapper/buffer/buffer_base.cpp +++ b/tests/cxx/unit_tests/tensorwrapper/buffer/buffer_base.cpp @@ -67,6 +67,26 @@ TEST_CASE("BufferBase") { REQUIRE(vector_base.layout().are_equal(vector_layout)); } + SECTION("operator()(std::string)") { + auto labeled_scalar = scalar_base(""); + REQUIRE(labeled_scalar.lhs().are_equal(scalar_base)); + REQUIRE(labeled_scalar.rhs() == ""); + + auto labeled_vector = vector_base("i"); + REQUIRE(labeled_vector.lhs().are_equal(vector_base)); + REQUIRE(labeled_vector.rhs() == "i"); + } + + SECTION("operator()(std::string) const") { + auto labeled_scalar = std::as_const(scalar_base)(""); + REQUIRE(labeled_scalar.lhs().are_equal(scalar_base)); + REQUIRE(labeled_scalar.rhs() == ""); + + auto labeled_vector = std::as_const(vector_base)("i"); + REQUIRE(labeled_vector.lhs().are_equal(vector_base)); + REQUIRE(labeled_vector.rhs() == "i"); + } + SECTION("operator==") { // Defaulted layout == defaulted layout REQUIRE(defaulted_base == scalar_buffer()); diff --git a/tests/cxx/unit_tests/tensorwrapper/buffer/eigen.cpp b/tests/cxx/unit_tests/tensorwrapper/buffer/eigen.cpp index 73572c8e..a4bb96c6 100644 --- a/tests/cxx/unit_tests/tensorwrapper/buffer/eigen.cpp +++ b/tests/cxx/unit_tests/tensorwrapper/buffer/eigen.cpp @@ -159,6 +159,43 @@ TEMPLATE_TEST_CASE("Eigen", "", float, double) { REQUIRE(pscalar.are_equal(scalar2)); REQUIRE_FALSE(pmatrix.are_equal(scalar2)); } + + SECTION("addition_assignment") { + SECTION("scalar") { + scalar_buffer scalar2(eigen_scalar, scalar_layout); + scalar2.value()() = 42.0; + + auto s = scalar(""); + auto pscalar2 = &(scalar2.addition_assignment("", s)); + + scalar_buffer scalar_corr(eigen_scalar, scalar_layout); + scalar_corr.value()() = 43.0; + REQUIRE(pscalar2 == &scalar2); + REQUIRE(scalar2 == scalar_corr); + } + + SECTION("vector") { + vector_buffer vector2(eigen_vector, vector_layout); + + auto vi = vector("i"); + auto pvector2 = &(vector2.addition_assignment("i", vi)); + + vector_buffer vector_corr(eigen_vector, vector_layout); + vector_corr.value()(0) = 2.0; + vector_corr.value()(1) = 4.0; + + REQUIRE(pvector2 == &vector2); + REQUIRE(vector2 == vector_corr); + + // Labels must match (for now) + REQUIRE_THROWS_AS(vector2.addition_assignment("j", vi), + std::runtime_error); + } + + // Can't cast + REQUIRE_THROWS_AS(vector.addition_assignment("", scalar("")), + std::runtime_error); + } } } }