Skip to content

Commit

Permalink
wip
Browse files Browse the repository at this point in the history
  • Loading branch information
MarcelKoch committed Sep 19, 2023
1 parent 167be0b commit ce36c16
Showing 1 changed file with 26 additions and 8 deletions.
34 changes: 26 additions & 8 deletions include/ginkgo/core/base/native_type.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,22 @@ namespace gko {
*/
namespace layout {

namespace detail {

/**
* Always use mutable arrays, because it is assumed that the
* array mapper allows he conversion `am<T> -> am<const T>`
*/
template <typename T>
auto make_mutable_array(std::shared_ptr<const Executor> exec, size_type size,
T* data)
{
using U = std::remove_cv_t<T>;
return gko::make_array_view(std::move(exec), size, const_cast<U*>(data));
}

} // namespace detail


/**
* A view of gko::device_matrix_data.
Expand All @@ -77,10 +93,12 @@ struct device_matrix_data {
size_type num_elems, IndexType* row_idxs,
IndexType* col_idxs, ValueType* values)
{
return {
array_mapper::map(gko::make_array_view(exec, num_elems, row_idxs)),
array_mapper::map(gko::make_array_view(exec, num_elems, col_idxs)),
array_mapper::map(gko::make_array_view(exec, num_elems, values))};
return {array_mapper::map(
detail::make_mutable_array(exec, num_elems, row_idxs)),
array_mapper::map(
detail::make_mutable_array(exec, num_elems, col_idxs)),
array_mapper::map(
detail::make_mutable_array(exec, num_elems, values))};
}

index_array row_idxs;
Expand Down Expand Up @@ -176,14 +194,14 @@ template <typename ValueType, typename IndexType, typename array_mapper,
typename dense_mapper>
struct native<const device_matrix_data<ValueType, IndexType>, array_mapper,
dense_mapper> {
using type = layout::device_matrix_data<ValueType, IndexType, array_mapper>;
using type = layout::device_matrix_data<const ValueType, const IndexType,
array_mapper>;

static type map(const device_matrix_data<ValueType, IndexType>& md)
{
return type::map(md.get_executor(), md.get_num_elems(),
const_cast<IndexType>(md.get_const_row_idxs()),
const_cast<IndexType>(md.get_const_col_idxs()),
const_cast<ValueType>(md.get_values()));
md.get_const_row_idxs(), md.get_const_col_idxs(),
md.get_values());
}
};

Expand Down

0 comments on commit ce36c16

Please sign in to comment.