Skip to content

Commit

Permalink
fixup! read copies into arrays directly
Browse files Browse the repository at this point in the history
  • Loading branch information
MarcelKoch committed Dec 4, 2023
1 parent f3e2286 commit 3e539bd
Show file tree
Hide file tree
Showing 2 changed files with 35 additions and 2 deletions.
17 changes: 16 additions & 1 deletion core/matrix/coo.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
#include <ginkgo/core/matrix/dense.hpp>


#include "core/base/device_matrix_data_kernels.hpp"
#include "core/components/absolute_array_kernels.hpp"
#include "core/components/fill_array_kernels.hpp"
#include "core/components/format_conversion_kernels.hpp"
Expand All @@ -43,6 +44,7 @@ GKO_REGISTER_OPERATION(inplace_absolute_array,
components::inplace_absolute_array);
GKO_REGISTER_OPERATION(outplace_absolute_array,
components::outplace_absolute_array);
GKO_REGISTER_OPERATION(aos_to_soa, components::aos_to_soa);


} // anonymous namespace
Expand Down Expand Up @@ -180,7 +182,20 @@ void Coo<ValueType, IndexType>::resize(dim<2> new_size, size_type nnz)
template <typename ValueType, typename IndexType>
void Coo<ValueType, IndexType>::read(const mat_data& data)
{
this->read(device_mat_data::create_from_host(this->get_executor(), data));
auto size = data.size;
auto exec = this->get_executor();
this->set_size(size);
row_idxs_.resize_and_reset(data.nonzeros.size());
col_idxs_.resize_and_reset(data.nonzeros.size());
values_.resize_and_reset(data.nonzeros.size());
device_mat_data view{exec, size, row_idxs_.as_view(), col_idxs_.as_view(),
values_.as_view()};
const auto host_data =
make_array_view(exec->get_master(), data.nonzeros.size(),
const_cast<matrix_data_entry<ValueType, IndexType>*>(
data.nonzeros.data()));
exec->run(
coo::make_aos_to_soa(*make_temporary_clone(exec, &host_data), view));
}


Expand Down
20 changes: 19 additions & 1 deletion core/matrix/csr.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,7 @@ GKO_REGISTER_OPERATION(inv_scale, csr::inv_scale);
GKO_REGISTER_OPERATION(add_scaled_identity, csr::add_scaled_identity);
GKO_REGISTER_OPERATION(check_diagonal_entries,
csr::check_diagonal_entries_exist);
GKO_REGISTER_OPERATION(aos_to_soa, components::aos_to_soa);


} // anonymous namespace
Expand Down Expand Up @@ -423,7 +424,24 @@ void Csr<ValueType, IndexType>::move_to(Fbcsr<ValueType, IndexType>* result)
template <typename ValueType, typename IndexType>
void Csr<ValueType, IndexType>::read(const mat_data& data)
{
this->read(device_mat_data::create_from_host(this->get_executor(), data));
auto size = data.size;
auto exec = this->get_executor();
row_ptrs_.resize_and_reset(size[0] + 1);
col_idxs_.resize_and_reset(data.nonzeros.size());
values_.resize_and_reset(data.nonzeros.size());
device_mat_data view{exec, size,
array<IndexType>{exec, data.nonzeros.size()},
col_idxs_.as_view(), values_.as_view()};
const auto host_data =
make_array_view(exec->get_master(), data.nonzeros.size(),
const_cast<matrix_data_entry<ValueType, IndexType>*>(
data.nonzeros.data()));
exec->run(
csr::make_aos_to_soa(*make_temporary_clone(exec, &host_data), view));
exec->run(csr::make_convert_idxs_to_ptrs(view.get_const_row_idxs(),
view.get_num_elems(), size[0],
get_row_ptrs()));
this->make_srow();
}


Expand Down

0 comments on commit 3e539bd

Please sign in to comment.