Skip to content

Commit

Permalink
wip fixup rebase
Browse files Browse the repository at this point in the history
  • Loading branch information
MarcelKoch committed May 8, 2024
1 parent 5ac4c62 commit d14576a
Show file tree
Hide file tree
Showing 3 changed files with 85 additions and 8 deletions.
54 changes: 54 additions & 0 deletions core/distributed/index_map.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,47 @@ size_type index_map<LocalIndexType, GlobalIndexType>::get_global_size() const
}


template <typename LocalIndexType, typename GlobalIndexType>
GlobalIndexType index_map<LocalIndexType, GlobalIndexType>::get_global(
const LocalIndexType& local_id, index_space is) const
{
if (is == index_space::local) {
auto host_part = make_temporary_clone(exec_->get_master(), partition_);
auto host_local_ranges =
make_temporary_clone(exec_->get_master(), &local_ranges_);
auto ranges = host_part->get_range_bounds();
auto range_start_idxs = host_part->get_range_starting_indices();

auto local_gid = static_cast<GlobalIndexType>(local_id);

// find-last
auto range_id = static_cast<size_type>(-1);
for (size_type i = 0; i < host_local_ranges->get_size(); ++i) {
if (range_start_idxs[host_local_ranges->get_const_data()[i]] <=
local_gid) {
range_id = host_local_ranges->get_const_data()[i];
}
}
GKO_THROW_IF_INVALID(range_id != static_cast<size_type>(-1),
"Index not part of local index space");

return ranges[range_id] + (local_gid - range_start_idxs[range_id]);
}
if (is == index_space::non_local) {
GKO_THROW_IF_INVALID(local_id < remote_global_idxs_.get_size(),
"Index not part of non-local index space");

auto host_remote_global_idxs =
make_const_array_view(exec_, remote_global_idxs_.get_size(),
remote_global_idxs_.get_const_flat_data())
.copy_to_array();
host_remote_global_idxs.set_executor(exec_->get_master());
return host_remote_global_idxs.get_const_data()[local_id];
}
GKO_NOT_IMPLEMENTED;
}


template <typename LocalIndexType, typename GlobalIndexType>
array<LocalIndexType> index_map<LocalIndexType, GlobalIndexType>::map_to_local(
const array<GlobalIndexType>& global_ids, index_space index_space_v) const
Expand All @@ -97,6 +138,7 @@ index_map<LocalIndexType, GlobalIndexType>::index_map(
const array<GlobalIndexType>& recv_connections)
: exec_(std::move(exec)),
partition_(std::move(partition)),
local_ranges_(exec_),
rank_(rank),
remote_target_ids_(exec_),
remote_local_idxs_(exec_),
Expand All @@ -112,6 +154,18 @@ index_map<LocalIndexType, GlobalIndexType>::index_map(
std::move(flat_remote_local_idxs), remote_sizes);
remote_global_idxs_ = segmented_array<GlobalIndexType>::create_from_sizes(
std::move(flat_remote_global_idxs), remote_sizes);

auto host_part = make_temporary_clone(exec_->get_master(), partition_);
auto part_ids = host_part->get_part_ids();

std::vector<size_type> host_local_ranges;
for (size_type i = 0; i < partition_->get_num_ranges(); ++i) {
if (part_ids[i] == rank_) {
host_local_ranges.push_back(i);
}
}
local_ranges_ = array<size_type>(exec_, host_local_ranges.begin(),
host_local_ranges.end());
}


Expand Down
24 changes: 17 additions & 7 deletions core/distributed/preconditioner/schwarz_ovlp.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,8 @@ std::vector<matrix_data_entry<ValueType, GlobalIndexType>> get_recv_rows(

auto row_gatherer = mtx->get_row_gatherer();
auto& send_idxs = row_gatherer->get_row_idxs();
auto host_send_idxs = make_temporary_clone(exec->get_master(), &send_idxs);
std::unique_ptr<segmented_array<LocalIndexType>> host_send_idxs;
// = make_temporary_clone(exec->get_master(), &send_idxs);

auto coll_comm = row_gatherer->get_collective_communicator();

Expand Down Expand Up @@ -123,8 +124,12 @@ std::vector<matrix_data_entry<ValueType, GlobalIndexType>> filter_non_relevant(
std::copy_if(input.begin(), input.end(), std::back_inserter(result),
[&](const auto& a) {
auto is = index_space::combined;
return imap.is_within_index_space(a.row, is) &&
imap.is_within_index_space(a.column, is);
return true;
// return
// imap.is_within_index_space(a.row,
// is) &&
// imap.is_within_index_space(a.column,
// is);
});
return result;
}
Expand All @@ -148,17 +153,22 @@ matrix_data<ValueType, LocalIndexType> combine_overlap(

for (auto& e : non_local.nonzeros) {
auto is = index_space::non_local;
e.column = imap.get_combined_local(e.column, is);
// e.column = imap.get_combined_local(e.column, is);
}

md local_recv_rows;
std::transform(recv_rows.begin(), recv_rows.end(),
std::back_inserter(local_recv_rows.nonzeros),
[&](const auto& e) {
auto is = index_space::combined;
return matrix_data_entry<ValueType, LocalIndexType>{
imap.get_local(e.row, is),
imap.get_local(e.column, is), e.value};
return typename md::nonzero_type{};
// return
// matrix_data_entry<ValueType,
// LocalIndexType>{
// imap.map_to_local(e.row,
// is),
// imap.map_to_local(e.column,
// is), e.value};
});

auto combined_size = imap.get_local_size() + imap.get_non_local_size();
Expand Down
15 changes: 14 additions & 1 deletion include/ginkgo/core/distributed/index_map.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,18 @@ struct index_map {
* \return the mapped local indices. Any global index that is not in the
* specified index space is mapped to invalid_index.
*/
GlobalIndexType get_global(const LocalIndexType& global_ids,
index_space is = index_space::combined) const;

/**
* \brief Maps global indices to local indices
*
* \param global_ids the global indices to map
* \param is the index space in which the returned local indices are defined
*
* \return the mapped local indices. Any global index that is not in the
* specified index space is mapped to invalid_index.
*/
array<LocalIndexType> map_to_local(
const array<GlobalIndexType>& global_ids,
index_space index_space_v = index_space::combined) const;
Expand Down Expand Up @@ -156,7 +168,7 @@ struct index_map {
*
* \return global partition used by the index map
*/
std::shared_ptr<const part_type> get_partition() const
std::shared_ptr<const partition_type> get_partition() const
{
return partition_;
}
Expand All @@ -177,6 +189,7 @@ struct index_map {
private:
std::shared_ptr<const Executor> exec_;
std::shared_ptr<const partition_type> partition_;
array<size_type> local_ranges_;
comm_index_type rank_;

array<comm_index_type> remote_target_ids_;
Expand Down

0 comments on commit d14576a

Please sign in to comment.