Skip to content

Commit

Permalink
Implement resize in shp (#662)
Browse files Browse the repository at this point in the history
  • Loading branch information
BenBrock authored Nov 29, 2023
1 parent d28fc3b commit 6636f9d
Show file tree
Hide file tree
Showing 3 changed files with 36 additions and 2 deletions.
8 changes: 6 additions & 2 deletions examples/shp/vector_example.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,6 @@

#include <dr/shp.hpp>

template <dr::distributed_iterator Iter> void iter(Iter) {}

int main(int argc, char **argv) {
printf("Creating NUMA devices...\n");
auto devices = dr::shp::get_numa_devices(sycl::default_selector_v);
Expand Down Expand Up @@ -51,5 +49,11 @@ int main(int argc, char **argv) {

dr::shp::print_range(local_vec, "local vec after copy");

v.resize(200);
dr::shp::print_range(v, "resized to 200");

v.resize(50);
dr::shp::print_range(v, "resized to 50");

return 0;
}
9 changes: 9 additions & 0 deletions include/dr/shp/distributed_vector.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -198,6 +198,15 @@ struct distributed_vector {
: begin();
}

void resize(size_type count, const value_type &value) {
distributed_vector<T, Allocator> other(count, value);
std::size_t copy_size = std::min(other.size(), size());
dr::shp::copy(begin(), begin() + copy_size, other.begin());
*this = std::move(other);
}

void resize(size_type count) { resize(count, value_type{}); }

private:
std::vector<segment_type> segments_;
std::size_t capacity_ = 0;
Expand Down
21 changes: 21 additions & 0 deletions test/gtest/shp/containers.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,27 @@ TYPED_TEST(DistributedVectorTest, Iterator) {
EXPECT_TRUE(std::equal(v_a.begin(), v_a.end(), dv_a.begin()));
}

TYPED_TEST(DistributedVectorTest, Resize) {
std::size_t size = 100;
typename TestFixture::DistVec dv(size);
dr::shp::iota(dv.begin(), dv.end(), 20);

typename TestFixture::LocalVec v(size);
std::iota(v.begin(), v.end(), 20);

dv.resize(size * 2);
v.resize(size * 2);
EXPECT_EQ(dv, v);

dv.resize(size);
v.resize(size);
EXPECT_EQ(dv, v);

dv.resize(size * 2, 12);
v.resize(size * 2, 12);
EXPECT_EQ(dv, v);
}

template <typename AllocT> class DeviceVectorTest : public testing::Test {
public:
using DeviceVec = dr::shp::device_vector<typename AllocT::value_type, AllocT>;
Expand Down

0 comments on commit 6636f9d

Please sign in to comment.