From 6636f9d0952b15b002c5ff80120cad9a2ff347f6 Mon Sep 17 00:00:00 2001 From: Benjamin Brock Date: Wed, 29 Nov 2023 14:06:02 -0800 Subject: [PATCH] Implement `resize` in shp (#662) --- examples/shp/vector_example.cpp | 8 ++++++-- include/dr/shp/distributed_vector.hpp | 9 +++++++++ test/gtest/shp/containers.cpp | 21 +++++++++++++++++++++ 3 files changed, 36 insertions(+), 2 deletions(-) diff --git a/examples/shp/vector_example.cpp b/examples/shp/vector_example.cpp index c106cefc14..a7eefdf2aa 100644 --- a/examples/shp/vector_example.cpp +++ b/examples/shp/vector_example.cpp @@ -4,8 +4,6 @@ #include -template void iter(Iter) {} - int main(int argc, char **argv) { printf("Creating NUMA devices...\n"); auto devices = dr::shp::get_numa_devices(sycl::default_selector_v); @@ -51,5 +49,11 @@ int main(int argc, char **argv) { dr::shp::print_range(local_vec, "local vec after copy"); + v.resize(200); + dr::shp::print_range(v, "resized to 200"); + + v.resize(50); + dr::shp::print_range(v, "resized to 50"); + return 0; } diff --git a/include/dr/shp/distributed_vector.hpp b/include/dr/shp/distributed_vector.hpp index 3d6f0061f4..823862c21b 100644 --- a/include/dr/shp/distributed_vector.hpp +++ b/include/dr/shp/distributed_vector.hpp @@ -198,6 +198,15 @@ struct distributed_vector { : begin(); } + void resize(size_type count, const value_type &value) { + distributed_vector other(count, value); + std::size_t copy_size = std::min(other.size(), size()); + dr::shp::copy(begin(), begin() + copy_size, other.begin()); + *this = std::move(other); + } + + void resize(size_type count) { resize(count, value_type{}); } + private: std::vector segments_; std::size_t capacity_ = 0; diff --git a/test/gtest/shp/containers.cpp b/test/gtest/shp/containers.cpp index e34f9090ae..f32a4bc9cf 100644 --- a/test/gtest/shp/containers.cpp +++ b/test/gtest/shp/containers.cpp @@ -84,6 +84,27 @@ TYPED_TEST(DistributedVectorTest, Iterator) { EXPECT_TRUE(std::equal(v_a.begin(), v_a.end(), dv_a.begin())); } +TYPED_TEST(DistributedVectorTest, Resize) { + std::size_t size = 100; + typename TestFixture::DistVec dv(size); + dr::shp::iota(dv.begin(), dv.end(), 20); + + typename TestFixture::LocalVec v(size); + std::iota(v.begin(), v.end(), 20); + + dv.resize(size * 2); + v.resize(size * 2); + EXPECT_EQ(dv, v); + + dv.resize(size); + v.resize(size); + EXPECT_EQ(dv, v); + + dv.resize(size * 2, 12); + v.resize(size * 2, 12); + EXPECT_EQ(dv, v); +} + template class DeviceVectorTest : public testing::Test { public: using DeviceVec = dr::shp::device_vector;