Skip to content

Commit

Permalink
addition_assignment tested
Browse files Browse the repository at this point in the history
  • Loading branch information
ryanmrichard committed Dec 10, 2024
1 parent 3beda27 commit ede020c
Show file tree
Hide file tree
Showing 8 changed files with 236 additions and 22 deletions.
66 changes: 66 additions & 0 deletions include/tensorwrapper/buffer/buffer_base.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,9 @@

#pragma once
#include <tensorwrapper/detail_/polymorphic_base.hpp>
#include <tensorwrapper/dsl/labeled.hpp>
#include <tensorwrapper/layout/layout_base.hpp>

namespace tensorwrapper::buffer {

/** @brief Common base class for all buffer objects.
Expand All @@ -35,6 +37,9 @@ class BufferBase : public detail_::PolymorphicBase<BufferBase> {
/// Type all buffers inherit from
using buffer_base_type = typename my_base_type::base_type;

/// Type of a mutable reference to a buffer_base_type object
using buffer_base_reference = typename my_base_type::base_reference;

/// Type of a read-only reference to a buffer_base_type object
using const_buffer_base_reference =
typename my_base_type::const_base_reference;
Expand All @@ -54,6 +59,18 @@ class BufferBase : public detail_::PolymorphicBase<BufferBase> {
/// Type of a pointer to the layout
using layout_pointer = typename layout_type::layout_pointer;

/// Type of labels for making a labeled buffer
using label_type = std::string;

/// Type of a labeled buffer
using labeled_buffer_type = dsl::Labeled<buffer_base_type, label_type>;

/// Type of a labeled read-only buffer (n.b. labels are mutable)
using labeled_const_buffer_type = dsl::Labeled<const buffer_base_type>;

/// Type of a read-only reference to a labeled_buffer_type object
using const_labeled_buffer_reference = const labeled_buffer_type&;

// -------------------------------------------------------------------------
// -- Accessors
// -------------------------------------------------------------------------
Expand Down Expand Up @@ -85,10 +102,53 @@ class BufferBase : public detail_::PolymorphicBase<BufferBase> {
return *m_layout_;
}

// -------------------------------------------------------------------------
// -- BLAS Operations
// -------------------------------------------------------------------------

buffer_base_reference addition_assignment(
label_type this_labels, const_labeled_buffer_reference rhs) {
return addition_assignment_(std::move(this_labels), rhs);
}

buffer_base_pointer addition(label_type this_labels,
const_labeled_buffer_reference rhs) const {
auto pthis = clone();
pthis->addition_assignment(std::move(this_labels), rhs);
return pthis;
}

// -------------------------------------------------------------------------
// -- Utility methods
// -------------------------------------------------------------------------

/** @brief Associates labels with the modes of *this.
*
* This method is used to create a labeled buffer object by pairing *this
* with the provided labels. The resulting object is capable of being
* composed via the DSL.
*
* @param[in] labels The indices to associate with the modes of *this.
*
* @return A DSL term pairing *this with @p labels.
*
* @throw None No throw guarantee.
*/
labeled_buffer_type operator()(label_type labels);

/** @brief Associates labels with the modes of *this.
*
* This method is the same as the non-const version except that the result
* contains a read-only reference to *this.
*
* @param[in] labels The labels to associate with *this.
*
* @return A DSL term pairing *this with @p labels.
*
* @throw None No throw guarantee.
*/
labeled_const_buffer_type operator()(label_type labels) const;

/** @brief Is *this value equal to @p rhs?
*
* Two BufferBase objects are value equal if the layouts they contain are
Expand Down Expand Up @@ -186,6 +246,12 @@ class BufferBase : public detail_::PolymorphicBase<BufferBase> {
return *this;
}

/// Derived class should overwrite to implement addition_assignment
virtual buffer_base_reference addition_assignment_(
label_type this_labels, const_labeled_buffer_reference rhs) {
throw std::runtime_error("Addition assignment NYI");
}

private:
/// Throws std::runtime_error when there is no layout
void assert_layout_() const {
Expand Down
46 changes: 25 additions & 21 deletions include/tensorwrapper/buffer/eigen.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@

#pragma once
#include <tensorwrapper/backends/eigen.hpp>
#include <tensorwrapper/buffer/eigen_any.hpp>
#include <tensorwrapper/buffer/replicated.hpp>

namespace tensorwrapper::buffer {

Expand All @@ -27,7 +27,7 @@ namespace tensorwrapper::buffer {
*
*/
template<typename FloatType, unsigned short Rank>
class Eigen : public EigenAny {
class Eigen : public Replicated {
private:
/// Type of *this
using my_type = Eigen<FloatType, Rank>;
Expand All @@ -39,7 +39,9 @@ class Eigen : public EigenAny {
/// Pull in base class's types
using typename my_base_type::buffer_base_pointer;
using typename my_base_type::const_buffer_base_reference;
using typename my_base_type::const_labeled_buffer_reference;
using typename my_base_type::const_layout_reference;
using typename my_base_type::label_type;

/// Type of a rank @p Rank tensor using floats of type @p FloatType
using data_type = eigen::data_type<FloatType, Rank>;
Expand Down Expand Up @@ -180,29 +182,31 @@ class Eigen : public EigenAny {
return my_base_type::are_equal_impl_<my_type>(rhs);
}

EigenAny& add_(const EigenAny& other) override {
auto& down_other = downcast_<FloatType, Rank>(other);
m_tensor_ += down_other.value();
return *this;
}

template<typename FloatType2, unsigned short Rank2>
static Eigen<FloatType2, N2>& downcast_(EigenAny& other) {
auto pother = dynamic_cast<Eigen<FloatType2, N2>*>(&other);
if(pother == nullptr) throw std::runtime_error("Not convertible");
return *pother;
}

template<typename FloatType2, unsigned short Rank2>
static const Eigen<FloatType2, N2>& downcast_(const EigenAny& other) {
auto pother = dynamic_cast<const Eigen<FloatType2, N2>*>(&other);
if(pother == nullptr) throw std::runtime_error("Not convertible");
return *pother;
}
/// Implements addition_assignment by rebinding rhs
buffer_base_reference addition_assignment_(
label_type this_labels, const_labeled_buffer_reference rhs) override;

private:
/// The actual Eigen tensor
data_type m_tensor_;
};

#define DECLARE_EIGEN_BUFFER(RANK) \
extern template class Eigen<float, RANK>; \
extern template class Eigen<double, RANK>

DECLARE_EIGEN_BUFFER(0);
DECLARE_EIGEN_BUFFER(1);
DECLARE_EIGEN_BUFFER(2);
DECLARE_EIGEN_BUFFER(3);
DECLARE_EIGEN_BUFFER(4);
DECLARE_EIGEN_BUFFER(5);
DECLARE_EIGEN_BUFFER(6);
DECLARE_EIGEN_BUFFER(7);
DECLARE_EIGEN_BUFFER(8);
DECLARE_EIGEN_BUFFER(9);
DECLARE_EIGEN_BUFFER(10);

#undef DECLARE_EIGEN_BUFFER

} // namespace tensorwrapper::buffer
15 changes: 15 additions & 0 deletions src/tensorwrapper/buffer/buffer_base.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
#include <tensorwrapper/buffer/buffer_base.hpp>

namespace tensorwrapper::buffer {

typename BufferBase::labeled_buffer_type BufferBase::operator()(
label_type labels) {
return labeled_buffer_type(*this, std::move(labels));
}

typename BufferBase::labeled_const_buffer_type BufferBase::operator()(
label_type labels) const {
return labeled_const_buffer_type(*this, std::move(labels));
}

} // namespace tensorwrapper::buffer
49 changes: 49 additions & 0 deletions src/tensorwrapper/buffer/eigen.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
#include <tensorwrapper/allocator/eigen.hpp>
#include <tensorwrapper/buffer/eigen.hpp>

namespace tensorwrapper::buffer {

#define TPARAMS template<typename FloatType, unsigned short Rank>
#define EIGEN Eigen<FloatType, Rank>

TPARAMS
typename EIGEN::buffer_base_reference EIGEN::addition_assignment_(
label_type this_labels, const_labeled_buffer_reference rhs) {
using allocator_type = allocator::Eigen<FloatType, Rank>;

if(this_labels != rhs.rhs())
throw std::runtime_error("Labels must match (for now)!");

if(layout() != rhs.lhs().layout())
throw std::runtime_error("Layouts must be the same (for now)");

const auto& rhs_downcasted = allocator_type::rebind(rhs.lhs());

m_tensor_ += rhs_downcasted.value();

// TODO layouts
return *this;
}

#undef EIGEN
#undef TPARAMS

#define DEFINE_EIGEN_BUFFER(RANK) \
template class Eigen<float, RANK>; \
template class Eigen<double, RANK>

DEFINE_EIGEN_BUFFER(0);
DEFINE_EIGEN_BUFFER(1);
DEFINE_EIGEN_BUFFER(2);
DEFINE_EIGEN_BUFFER(3);
DEFINE_EIGEN_BUFFER(4);
DEFINE_EIGEN_BUFFER(5);
DEFINE_EIGEN_BUFFER(6);
DEFINE_EIGEN_BUFFER(7);
DEFINE_EIGEN_BUFFER(8);
DEFINE_EIGEN_BUFFER(9);
DEFINE_EIGEN_BUFFER(10);

#undef DEFINE_EIGEN_BUFFER

} // namespace tensorwrapper::buffer
1 change: 0 additions & 1 deletion src/tensorwrapper/dsl/executor/eigen.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
*/

#pragma once
#include "detail_/eigen_dispatcher.hpp"
#include <tensorwrapper/tensor/tensor_class.hpp>
namespace tensorwrapper::dsl::executor {

Expand Down
24 changes: 24 additions & 0 deletions tests/cxx/unit_tests/tensorwrapper/allocator/eigen.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
#include "../helpers.hpp"
#include <parallelzone/parallelzone.hpp>
#include <tensorwrapper/allocator/eigen.hpp>
#include <tensorwrapper/buffer/eigen.hpp>
#include <tensorwrapper/shape/smooth.hpp>

using namespace tensorwrapper;
Expand Down Expand Up @@ -124,6 +125,29 @@ TEMPLATE_TEST_CASE("EigenAllocator", "", float, double) {
REQUIRE_THROWS_AS(scalar_alloc.allocate(vector_layout), except_t);
}

SECTION("can_rebind") {
REQUIRE(scalar_alloc.can_rebind(scalar_corr));
REQUIRE_FALSE(scalar_alloc.can_rebind(vector_corr));
}

SECTION("rebind(non-const)") {
using type = typename scalar_alloc_type::buffer_base_reference;
type scalar_base = scalar_corr;
auto& eigen_buffer = scalar_alloc.rebind(scalar_base);
REQUIRE(&eigen_buffer == &scalar_corr);
REQUIRE_THROWS_AS(scalar_alloc.rebind(vector_corr), std::runtime_error);
}

SECTION("rebind(const)") {
using type = typename scalar_alloc_type::const_buffer_base_reference;
type scalar_base = scalar_corr;
auto& eigen_buffer = scalar_alloc.rebind(scalar_base);
REQUIRE(&eigen_buffer == &scalar_corr);

type vector_base = vector_corr;
REQUIRE_THROWS_AS(scalar_alloc.rebind(vector_base), std::runtime_error);
}

SECTION("operator==") {
REQUIRE(scalar_alloc == scalar_alloc_type(rv));
REQUIRE_FALSE(scalar_alloc == vector_alloc);
Expand Down
20 changes: 20 additions & 0 deletions tests/cxx/unit_tests/tensorwrapper/buffer/buffer_base.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,26 @@ TEST_CASE("BufferBase") {
REQUIRE(vector_base.layout().are_equal(vector_layout));
}

SECTION("operator()(std::string)") {
auto labeled_scalar = scalar_base("");
REQUIRE(labeled_scalar.lhs().are_equal(scalar_base));
REQUIRE(labeled_scalar.rhs() == "");

auto labeled_vector = vector_base("i");
REQUIRE(labeled_vector.lhs().are_equal(vector_base));
REQUIRE(labeled_vector.rhs() == "i");
}

SECTION("operator()(std::string) const") {
auto labeled_scalar = std::as_const(scalar_base)("");
REQUIRE(labeled_scalar.lhs().are_equal(scalar_base));
REQUIRE(labeled_scalar.rhs() == "");

auto labeled_vector = std::as_const(vector_base)("i");
REQUIRE(labeled_vector.lhs().are_equal(vector_base));
REQUIRE(labeled_vector.rhs() == "i");
}

SECTION("operator==") {
// Defaulted layout == defaulted layout
REQUIRE(defaulted_base == scalar_buffer());
Expand Down
37 changes: 37 additions & 0 deletions tests/cxx/unit_tests/tensorwrapper/buffer/eigen.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -159,6 +159,43 @@ TEMPLATE_TEST_CASE("Eigen", "", float, double) {
REQUIRE(pscalar.are_equal(scalar2));
REQUIRE_FALSE(pmatrix.are_equal(scalar2));
}

SECTION("addition_assignment") {
SECTION("scalar") {
scalar_buffer scalar2(eigen_scalar, scalar_layout);
scalar2.value()() = 42.0;

auto s = scalar("");
auto pscalar2 = &(scalar2.addition_assignment("", s));

scalar_buffer scalar_corr(eigen_scalar, scalar_layout);
scalar_corr.value()() = 43.0;
REQUIRE(pscalar2 == &scalar2);
REQUIRE(scalar2 == scalar_corr);
}

SECTION("vector") {
vector_buffer vector2(eigen_vector, vector_layout);

auto vi = vector("i");
auto pvector2 = &(vector2.addition_assignment("i", vi));

vector_buffer vector_corr(eigen_vector, vector_layout);
vector_corr.value()(0) = 2.0;
vector_corr.value()(1) = 4.0;

REQUIRE(pvector2 == &vector2);
REQUIRE(vector2 == vector_corr);

// Labels must match (for now)
REQUIRE_THROWS_AS(vector2.addition_assignment("j", vi),
std::runtime_error);
}

// Can't cast
REQUIRE_THROWS_AS(vector.addition_assignment("", scalar("")),
std::runtime_error);
}
}
}
}

0 comments on commit ede020c

Please sign in to comment.