Skip to content

Commit

Permalink
adds rebind to the allocator
Browse files Browse the repository at this point in the history
  • Loading branch information
ryanmrichard committed Dec 10, 2024
1 parent 0316e9c commit 3beda27
Show file tree
Hide file tree
Showing 2 changed files with 77 additions and 1 deletion.
56 changes: 55 additions & 1 deletion include/tensorwrapper/allocator/eigen.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@

#pragma once
#include <tensorwrapper/allocator/replicated.hpp>
#include <tensorwrapper/buffer/eigen.hpp>
#include <tensorwrapper/buffer/buffer_fwd.hpp>

namespace tensorwrapper::allocator {

Expand Down Expand Up @@ -44,12 +44,19 @@ class Eigen : public Replicated {
using my_base_type::buffer_base_pointer;
using my_base_type::buffer_base_reference;
using my_base_type::const_base_reference;
using my_base_type::const_buffer_base_reference;
using my_base_type::layout_pointer;
using my_base_type::runtime_view_type;

/// Type of a buffer containing an Eigen tensor
using eigen_buffer_type = buffer::Eigen<FloatType, Rank>;

/// Type of a mutable reference to an object of type eigen_buffer_type
using eigen_buffer_reference = eigen_buffer_type&;

/// Type of a read-only reference to an object of type eigen_buffer_type
using const_eigen_buffer_reference = const eigen_buffer_type&;

/// Type of a pointer to an eigen_buffer_type object
using eigen_buffer_pointer = std::unique_ptr<eigen_buffer_type>;

Expand Down Expand Up @@ -141,6 +148,53 @@ class Eigen : public Replicated {
return pbuffer;
}

/** @brief Determines if @p buffer can be rebound as an Eigen buffer.
*
* Rebinding a buffer allows the same memory to be viewed as a (possibly)
* different type of buffer.
*
* @param[in] buffer The tensor we are attempting to rebind.
*
* @return True if @p buffer can be rebound to the type of buffer
* associated with this allocator and false otherwise.
*
* @throw None No throw guarantee
*/
static bool can_rebind(const_buffer_base_reference buffer);

/** @brief Rebinds a buffer to the same type as *this.
*
* This method will convert @p buffer into a buffer which could have been
* allocated by *this. If @p buffer was allocated as such a buffer already,
* then this method is simply a downcast.
*
* @param[in] buffer The buffer to rebind.
*
* @return A mutable reference to @p buffer viewed as a buffer that could
* have been allocated by *this.
*
* @throw std::runtime_error if can_rebind(buffer) is false. Strong throw
* guarantee.
*/
static eigen_buffer_reference rebind(buffer_base_reference buffer);

/** @brief Rebinds a buffer to the same type as *this.
*
* This method is the same as the non-const version except that the result
* is read-only. See the description for the non-const version for more
* details.
*
* @param[in] buffer The buffer to rebind.
*
* @return A read-only reference to @p buffer viewed as if it was
* allocated by *this.
*
* @throw std::runtime_error if can_rebind(buffer) is false. Strong throw
* guarantee.
*/
static const_eigen_buffer_reference rebind(
const_buffer_base_reference buffer);

/** @brief Is *this value equal to @p rhs?
*
* @tparam FloatType2 The numerical type @p rhs uses for its elements.
Expand Down
22 changes: 22 additions & 0 deletions src/tensorwrapper/allocator/eigen.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
*/

#include <tensorwrapper/allocator/eigen.hpp>
#include <tensorwrapper/buffer/eigen.hpp>
#include <tensorwrapper/detail_/unique_ptr_utilities.hpp>
#include <tensorwrapper/shape/smooth.hpp>

Expand Down Expand Up @@ -45,6 +46,27 @@ typename EIGEN::eigen_buffer_pointer EIGEN::allocate(
*playout);
}

TPARAMS
bool EIGEN::can_rebind(const_buffer_base_reference buffer) {
auto pbuffer = dynamic_cast<const eigen_buffer_type*>(&buffer);
return pbuffer != nullptr;
}

TPARAMS
typename EIGEN::eigen_buffer_reference EIGEN::rebind(
buffer_base_reference buffer) {
if(can_rebind(buffer)) return static_cast<eigen_buffer_reference>(buffer);
throw std::runtime_error("Can not rebind buffer");
}

TPARAMS
typename EIGEN::const_eigen_buffer_reference EIGEN::rebind(
const_buffer_base_reference buffer) {
if(can_rebind(buffer))
return dynamic_cast<const_eigen_buffer_reference>(buffer);
throw std::runtime_error("Can not rebind buffer");
}

#define ALLOCATE_CONDITION(RANK) \
if(rank == RANK) return std::make_unique<Eigen<FloatType, RANK>>(rv)

Expand Down

0 comments on commit 3beda27

Please sign in to comment.