Skip to content

Commit

Permalink
Bindings for hpx/algorithm.hpp (#7)
Browse files Browse the repository at this point in the history
* added hpx::copy

* added hpx::copy_n

* added hpx::copy_if

* added hpx::count, hpx::count_if

* rerun tests

* added hpx::ends_with

* added hpx::equal

* added hpx::fill but it only works for 1D vector

* added hpx::find

* added hpx::sort

* added hpx::sort along with comparator closure as an argument

* added hpx::merge

* added hpx::partial_sort

* moved wrappers from tests to main

* removed redundant copy from copy, sort

* removed redundant copy from copy_n

* removed redundant copy from find, fill, sort_comp

* removed redundant copy from merge, partial sort

* removed redundant copy from copy_if

* rerun tests

---------

Signed-off-by: Dikshant <[email protected]>
  • Loading branch information
pingu-73 authored Sep 1, 2024
1 parent 66a30c1 commit 4bb45f8
Show file tree
Hide file tree
Showing 3 changed files with 580 additions and 23 deletions.
2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
[package]
name = "hpx-rs"
version = "0.1.0"
authors = ["Shreyas Atre <[email protected]>", "Dikshant <[email protected]"]
authors = ["Shreyas Atre <[email protected]>", "Dikshant <[email protected]>"]
edition = "2021"
readme = "README.md"
repository = "https://github.com/STEllAR-GROUP/hpx-rs"
Expand Down
135 changes: 123 additions & 12 deletions hpx-sys/include/wrapper.h
Original file line number Diff line number Diff line change
@@ -1,23 +1,13 @@
#pragma once

#include <hpx/hpx_init.hpp>
#include <hpx/algorithm.hpp>
#include <iostream>
#include <cstdint>
#include <vector>

#include "rust/cxx.h"


/*inline std::int32_t start() { return hpx::start(nullptr, 0, nullptr); }*/

/*inline std::int32_t start(rust::Fn<int(int, char **)> rust_fn, int argc, char **argv) {*/
/* return hpx::start(*/
/* [&](int argc, char **argv) {*/
/* return rust_fn(argc, argv);*/
/* },*/
/* argc, argv);*/
/*}*/

inline std::int32_t init(rust::Fn<int(int, char **)> rust_fn, int argc, char **argv) {
return hpx::init(
[&](int argc, char **argv) {
Expand All @@ -44,4 +34,125 @@ inline std::int32_t disconnect_with_timeout(double shutdown_timeout, double loca

inline std::int32_t finalize() { return hpx::finalize(); }

/*inline std::int32_t stop() { return hpx::stop(); }*/
inline void hpx_copy(rust::Slice<const int32_t> src, rust::Slice<int32_t> dest) {
hpx::copy(hpx::execution::par, src.begin(), src.end(), dest.begin());
}

inline void hpx_copy_n(rust::Slice<const int32_t> src, size_t count, rust::Slice<int32_t> dest) {
hpx::copy_n(hpx::execution::par, src.begin(), count, dest.begin());
}

inline void hpx_copy_if(const rust::Vec<int32_t>& src, rust::Vec<int32_t>& dest,
rust::Fn<bool(int32_t)> pred) {
std::vector<int32_t> cpp_dest(src.size());

auto result = hpx::copy_if(hpx::execution::par,
src.begin(), src.end(),
cpp_dest.begin(),
[&](int32_t value) { return pred(value); });

cpp_dest.resize(std::distance(cpp_dest.begin(), result));

dest.clear();
dest.reserve(cpp_dest.size());
for (const auto& item : cpp_dest) {
dest.push_back(item);
}
}

inline std::int64_t hpx_count(const rust::Vec<int32_t>& vec, int32_t value) {
return hpx::count(hpx::execution::par, vec.begin(), vec.end(), value);
}


inline int64_t hpx_count_if(const rust::Vec<int32_t>& vec, rust::Fn<bool(int32_t)> pred) {
std::vector<int32_t> cpp_vec(vec.begin(), vec.end());

auto result = hpx::count_if(hpx::execution::par,
cpp_vec.begin(),
cpp_vec.end(),
[&](int32_t value) { return pred(value); });

return static_cast<int64_t>(result);
}

inline bool hpx_ends_with(rust::Slice<const int32_t> src,
rust::Slice<const int32_t> dest) {
return hpx::ends_with(hpx::execution::par,
src.begin(), src.end(),
dest.begin(), dest.end(),
std::equal_to<int32_t>());
}

inline bool hpx_equal(rust::Slice<const int32_t> src, rust::Slice<const int32_t> dest) {
return hpx::equal(
hpx::execution::par,
src.begin(), src.end(),
dest.begin(), dest.end()
);
}

inline void hpx_fill(rust::Slice<int32_t> src, int32_t value) {
hpx::fill(hpx::execution::par, src.begin(), src.end(), value);
}

inline int64_t hpx_find(rust::Slice<const int32_t> src, int32_t value) {
auto result = hpx::find(hpx::execution::par,
src.begin(),
src.end(),
value);

if (result != src.end()) {
return static_cast<int64_t>(std::distance(src.begin(), result));
}
return -1;
}

inline void hpx_sort(rust::Slice<int32_t> src) {
hpx::sort(hpx::execution::par, src.begin(), src.end());
}

inline void hpx_sort_comp(rust::Vec<int32_t>& src, rust::Fn<bool(int32_t, int32_t)> comp) {
hpx::sort(hpx::execution::par, src.begin(), src.end(),
[&](int32_t a, int32_t b) { return comp(a, b); });
}

inline void hpx_merge(rust::Slice<const int32_t> src1,
rust::Slice<const int32_t> src2,
rust::Vec<int32_t>& dest) {
dest.clear();
dest.reserve(src1.size() + src2.size());

for (size_t i = 0; i < src1.size() + src2.size(); ++i) {
dest.push_back(0);
}

hpx::merge(hpx::execution::par,
src1.begin(), src1.end(),
src2.begin(), src2.end(),
dest.begin());
}

inline void hpx_partial_sort(rust::Vec<int32_t>& src, size_t last) {
if (last > src.size()) {
last = src.size();
}

hpx::partial_sort(hpx::execution::par,
src.begin(),
src.begin() + last,
src.end());
}

inline void hpx_partial_sort_comp(rust::Vec<int32_t>& src, size_t last,
rust::Fn<bool(int32_t, int32_t)> comp) {
if (last > src.size()) {
last = src.size();
}

hpx::partial_sort(hpx::execution::par,
src.begin(),
src.begin() + last,
src.end(),
[&](int32_t a, int32_t b) { return comp(a, b); });
}
Loading

0 comments on commit 4bb45f8

Please sign in to comment.