From 9437f5b172d3109fc8469fd94af20ffc66f36d2b Mon Sep 17 00:00:00 2001 From: Marcel Koch Date: Tue, 30 Apr 2024 16:55:11 +0200 Subject: [PATCH] wip fixup rebase --- core/distributed/index_map.cpp | 54 +++++++++++++++++++ .../preconditioner/schwarz_ovlp.hpp | 25 ++++++--- include/ginkgo/core/distributed/index_map.hpp | 15 +++++- 3 files changed, 85 insertions(+), 9 deletions(-) diff --git a/core/distributed/index_map.cpp b/core/distributed/index_map.cpp index 864a3816324..0d46c0274c0 100644 --- a/core/distributed/index_map.cpp +++ b/core/distributed/index_map.cpp @@ -75,6 +75,47 @@ size_type index_map::get_global_size() const } +template +GlobalIndexType index_map::get_global( + const LocalIndexType& local_id, index_space is) const +{ + if (is == index_space::local) { + auto host_part = make_temporary_clone(exec_->get_master(), partition_); + auto host_local_ranges = + make_temporary_clone(exec_->get_master(), &local_ranges_); + auto ranges = host_part->get_range_bounds(); + auto range_start_idxs = host_part->get_range_starting_indices(); + + auto local_gid = static_cast(local_id); + + // find-last + auto range_id = static_cast(-1); + for (size_type i = 0; i < host_local_ranges->get_size(); ++i) { + if (range_start_idxs[host_local_ranges->get_const_data()[i]] <= + local_gid) { + range_id = host_local_ranges->get_const_data()[i]; + } + } + GKO_THROW_IF_INVALID(range_id != static_cast(-1), + "Index not part of local index space"); + + return ranges[range_id] + (local_gid - range_start_idxs[range_id]); + } + if (is == index_space::non_local) { + GKO_THROW_IF_INVALID(local_id < remote_global_idxs_.get_size(), + "Index not part of non-local index space"); + + auto host_remote_global_idxs = + make_const_array_view(exec_, remote_global_idxs_.get_size(), + remote_global_idxs_.get_flat_data()) + .copy_to_array(); + host_remote_global_idxs.set_executor(exec_->get_master()); + return host_remote_global_idxs.get_const_data()[local_id]; + } + GKO_NOT_IMPLEMENTED; +} + + template array index_map::get_local( const array& global_ids, index_space index_space_v) const @@ -96,6 +137,7 @@ index_map::index_map( const array& recv_connections) : exec_(std::move(exec)), partition_(std::move(partition)), + local_ranges_(exec_), rank_(rank), remote_target_ids_(exec_), remote_local_idxs_(exec_), @@ -113,6 +155,18 @@ index_map::index_map( remote_global_idxs_ = segmented_array::create_from_sizes( std::move(flat_remote_global_idxs), {exec_, remote_sizes.begin(), remote_sizes.end()}); + + auto host_part = make_temporary_clone(exec_->get_master(), partition_); + auto part_ids = host_part->get_part_ids(); + + std::vector host_local_ranges; + for (size_type i = 0; i < partition_->get_num_ranges(); ++i) { + if (part_ids[i] == rank_) { + host_local_ranges.push_back(i); + } + } + local_ranges_ = array(exec_, host_local_ranges.begin(), + host_local_ranges.end()); } diff --git a/core/distributed/preconditioner/schwarz_ovlp.hpp b/core/distributed/preconditioner/schwarz_ovlp.hpp index 040e7608a2a..090e10c53fa 100644 --- a/core/distributed/preconditioner/schwarz_ovlp.hpp +++ b/core/distributed/preconditioner/schwarz_ovlp.hpp @@ -39,7 +39,8 @@ std::vector> get_recv_rows( auto row_gatherer = mtx->get_row_gatherer(); auto& send_idxs = row_gatherer->get_row_idxs(); - auto host_send_idxs = make_temporary_clone(exec->get_master(), &send_idxs); + std::unique_ptr> host_send_idxs; + // = make_temporary_clone(exec->get_master(), &send_idxs); auto coll_comm = row_gatherer->get_collective_communicator(); @@ -61,7 +62,7 @@ std::vector> get_recv_rows( auto prev_size = global_md.nonzeros.size(); for (size_type i = send_offsets->get_const_data()[sid]; i < send_offsets->get_const_data()[sid + 1]; ++i) { - auto row = host_send_idxs->get_flat().get_const_data()[i]; + auto row = host_send_idxs->get_flat_data()[i]; for (LocalIndexType idx = host_local->get_const_row_ptrs()[row]; idx < host_local->get_const_row_ptrs()[row + 1]; ++idx) { @@ -123,8 +124,12 @@ std::vector> filter_non_relevant( std::copy_if(input.begin(), input.end(), std::back_inserter(result), [&](const auto& a) { auto is = index_space::combined; - return imap.is_within_index_space(a.row, is) && - imap.is_within_index_space(a.column, is); + return true; + // return + // imap.is_within_index_space(a.row, + // is) && + // imap.is_within_index_space(a.column, + // is); }); return result; } @@ -148,7 +153,7 @@ matrix_data combine_overlap( for (auto& e : non_local.nonzeros) { auto is = index_space::non_local; - e.column = imap.get_combined_local(e.column, is); + // e.column = imap.get_combined_local(e.column, is); } md local_recv_rows; @@ -156,9 +161,13 @@ matrix_data combine_overlap( std::back_inserter(local_recv_rows.nonzeros), [&](const auto& e) { auto is = index_space::combined; - return matrix_data_entry{ - imap.get_local(e.row, is), - imap.get_local(e.column, is), e.value}; + return typename md::nonzero_type{}; + // return + // matrix_data_entry{ + // imap.get_local(e.row, is), + // imap.get_local(e.column, + // is), e.value}; }); auto combined_size = imap.get_local_size() + imap.get_non_local_size(); diff --git a/include/ginkgo/core/distributed/index_map.hpp b/include/ginkgo/core/distributed/index_map.hpp index d769528b52f..315dca0897b 100644 --- a/include/ginkgo/core/distributed/index_map.hpp +++ b/include/ginkgo/core/distributed/index_map.hpp @@ -73,6 +73,18 @@ struct index_map { * \return the mapped local indices. Any global index that is not in the * specified index space is mapped to invalid_index. */ + GlobalIndexType get_global(const LocalIndexType& global_ids, + index_space is = index_space::combined) const; + + /** + * \brief Maps global indices to local indices + * + * \param global_ids the global indices to map + * \param is the index space in which the returned local indices are defined + * + * \return the mapped local indices. Any global index that is not in the + * specified index space is mapped to invalid_index. + */ array get_local( const array& global_ids, index_space index_space_v = index_space::combined) const; @@ -151,7 +163,7 @@ struct index_map { * * \return global partition used by the index map */ - std::shared_ptr get_partition() const + std::shared_ptr get_partition() const { return partition_; } @@ -172,6 +184,7 @@ struct index_map { private: std::shared_ptr exec_; std::shared_ptr partition_; + array local_ranges_; comm_index_type rank_; array remote_target_ids_;