Skip to content

Commit

Permalink
Adds pairwise_point_polygon_distance benchmark (#1131)
Browse files Browse the repository at this point in the history
This PR separates the `pairwise_point_polygon_distance` benchmark portion of PR #1002. While that PR is only left for nvtx3 experiments.

# Original PR description:

This PR adds pairwise point polygon distance benchmark. Depends on #998

Point-polygon distance performance can be affected by many factors, because the geometry is complex in nature. I benchmarked these questions:
1. How does the algorithm scales with simple multipolygons?
2. How does it scales with complex multipolygons?

## How does the algorithm scales with simple multipolygons?
The benchmark uses the most simple multipolygon, 3 sides per polygon, 0 hole and 1 polygon per multipolygon.

Float32
| Num   multipolygon | Throughput   (#multipolygons / s) |
| --- | --- |
| 1 | 28060.32971 |
| 100 | 2552687.469 |
| 10000 | 186044781 |
| 1000000 | 1047783101 |
| 100000000 | 929537385.2 | 

Float64
| Num   multipolygon | Throughput   (#multipolygons / s) |
| --- | --- |
| 1 | 28296.94817 |
| 100 | 2491541.218 |
| 10000 | 179379919.5 |
| 1000000 | 854678939.9 |
| 100000000 | 783364410.7 |

![image](https://user-images.githubusercontent.com/13521008/226502300-c3273d80-5f9f-4d53-b961-a24e64216e9b.png)

The chart shows that with simple polygons and simple multipoint (1 point per multipoint), the algorithm scales pretty nicely. Throughput is maxed out at near 1M pairs.

## How does the algorithm scales with complex multipolygons?

The benchmark uses a complex multipolygon, 100 edges per ring, 10 holes per polygon and 3 polygons per multipolygon.

float32
Num   multipolygon | Throughput   (#multipolygons / s)
-- | --
1000 | 158713.2377
10000 | 345694.2642
100000 | 382849.058

float64
Num   multipolygon | Throughput   (#multipolygons / s)
-- | --
1000 | 148727.1246
10000 | 353141.9758
100000 | 386007.3016

![image](https://user-images.githubusercontent.com/13521008/226502732-0d116db7-6257-4dec-b170-c42b30df9cea.png)

The algorithm reaches max throughput at near 10K pairs. About 100X lower than the simple multipolygon example.

Authors:
  - Michael Wang (https://github.com/isVoid)
  - Mark Harris (https://github.com/harrism)

Approvers:
  - Mark Harris (https://github.com/harrism)

URL: #1131
  • Loading branch information
isVoid authored May 23, 2023
1 parent c6ecbc0 commit 7f42eeb
Show file tree
Hide file tree
Showing 16 changed files with 269 additions and 42 deletions.
1 change: 1 addition & 0 deletions cpp/benchmarks/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,7 @@ ConfigureBench(HAUSDORFF_BENCH
distance/hausdorff_benchmark.cpp)

ConfigureNVBench(DISTANCES_BENCH
distance/pairwise_point_polygon_distance.cu
distance/pairwise_linestring_distance.cu)

ConfigureNVBench(QUADTREE_ON_POINTS_BENCH
Expand Down
134 changes: 134 additions & 0 deletions cpp/benchmarks/distance/pairwise_point_polygon_distance.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,134 @@
/*
* Copyright (c) 2023, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include <benchmarks/fixture/rmm_pool_raii.hpp>
#include <nvbench/nvbench.cuh>

#include <cuspatial_test/geometry_generator.cuh>

#include <cuspatial/distance.cuh>
#include <cuspatial/geometry/vec_2d.hpp>

#include <rmm/cuda_stream_view.hpp>
#include <rmm/exec_policy.hpp>

using namespace cuspatial;
using namespace cuspatial::test;

template <typename T>
void pairwise_point_polygon_distance_benchmark(nvbench::state& state, nvbench::type_list<T>)
{
// TODO: to be replaced by nvbench fixture once it's ready
cuspatial::rmm_pool_raii rmm_pool;
rmm::cuda_stream_view stream{rmm::cuda_stream_default};

auto const num_pairs{static_cast<std::size_t>(state.get_int64("num_pairs"))};

auto const num_polygons_per_multipolygon{
static_cast<std::size_t>(state.get_int64("num_polygons_per_multipolygon"))};
auto const num_holes_per_polygon{
static_cast<std::size_t>(state.get_int64("num_holes_per_polygon"))};
auto const num_edges_per_ring{static_cast<std::size_t>(state.get_int64("num_edges_per_ring"))};

auto const num_points_per_multipoint{
static_cast<std::size_t>(state.get_int64("num_points_per_multipoint"))};

auto mpoly_generator_param = multipolygon_generator_parameter<T>{
num_pairs, num_polygons_per_multipolygon, num_holes_per_polygon, num_edges_per_ring};

auto mpoint_generator_param = multipoint_generator_parameter<T>{
num_pairs, num_points_per_multipoint, vec_2d<T>{-1, -1}, vec_2d<T>{0, 0}};

auto multipolygons = generate_multipolygon_array<T>(mpoly_generator_param, stream);
auto multipoints = generate_multipoint_array<T>(mpoint_generator_param, stream);

auto distances = rmm::device_vector<T>(num_pairs);
auto out_it = distances.begin();

auto mpoly_view = multipolygons.range();
auto mpoint_view = multipoints.range();

state.add_element_count(num_pairs, "NumPairs");
state.add_element_count(mpoly_generator_param.num_polygons(), "NumPolygons");
state.add_element_count(mpoly_generator_param.num_rings(), "NumRings");
state.add_element_count(mpoly_generator_param.num_coords(), "NumPoints (in mpoly)");
state.add_element_count(static_cast<std::size_t>(mpoly_generator_param.num_coords() *
mpoly_generator_param.num_rings() *
mpoly_generator_param.num_polygons()),
"Multipolygon Complexity");
state.add_element_count(mpoint_generator_param.num_points(), "NumPoints (in multipoints)");

state.add_global_memory_reads<T>(
mpoly_generator_param.num_coords() + mpoint_generator_param.num_points(),
"CoordinatesReadSize");
state.add_global_memory_reads<std::size_t>(
(mpoly_generator_param.num_rings() + 1) + (mpoly_generator_param.num_polygons() + 1) +
(mpoly_generator_param.num_multipolygons + 1) + (mpoint_generator_param.num_multipoints + 1),
"OffsetsDataSize");

state.add_global_memory_writes<T>(num_pairs);

state.exec(nvbench::exec_tag::sync,
[&mpoly_view, &mpoint_view, &out_it, &stream](nvbench::launch& launch) {
pairwise_point_polygon_distance(mpoint_view, mpoly_view, out_it, stream);
});
}

using floating_point_types = nvbench::type_list<float, double>;

// Benchmark scalability with simple multipolygon (3 sides, 0 hole, 1 poly)
NVBENCH_BENCH_TYPES(pairwise_point_polygon_distance_benchmark,
NVBENCH_TYPE_AXES(floating_point_types))
.set_type_axes_names({"CoordsType"})
.add_int64_axis("num_pairs", {1, 1'00, 10'000, 1'000'000, 100'000'000})
.add_int64_axis("num_polygons_per_multipolygon", {1})
.add_int64_axis("num_holes_per_polygon", {0})
.add_int64_axis("num_edges_per_ring", {3})
.add_int64_axis("num_points_per_multipoint", {1})
.set_name("point_polygon_distance_benchmark_simple_polygon");

// Benchmark scalability with complex multipolygon (100 sides, 10 holes, 3 polys)
NVBENCH_BENCH_TYPES(pairwise_point_polygon_distance_benchmark,
NVBENCH_TYPE_AXES(floating_point_types))
.set_type_axes_names({"CoordsType"})
.add_int64_axis("num_pairs", {1'000, 10'000, 100'000, 1'000'000})
.add_int64_axis("num_polygons_per_multipolygon", {2})
.add_int64_axis("num_holes_per_polygon", {3})
.add_int64_axis("num_edges_per_ring", {50})
.add_int64_axis("num_points_per_multipoint", {1})
.set_name("point_polygon_distance_benchmark_complex_polygon");

// // Benchmark impact of rings (100K pairs, 1 polygon, 3 sides)
NVBENCH_BENCH_TYPES(pairwise_point_polygon_distance_benchmark,
NVBENCH_TYPE_AXES(floating_point_types))
.set_type_axes_names({"CoordsType"})
.add_int64_axis("num_pairs", {10'000})
.add_int64_axis("num_polygons_per_multipolygon", {1})
.add_int64_axis("num_holes_per_polygon", {0, 10, 100, 1000})
.add_int64_axis("num_edges_per_ring", {3})
.add_int64_axis("num_points_per_multipoint", {1})
.set_name("point_polygon_distance_benchmark_ring_numbers");

// Benchmark impact of rings (1M pairs, 1 polygon, 0 holes, 3 sides)
NVBENCH_BENCH_TYPES(pairwise_point_polygon_distance_benchmark,
NVBENCH_TYPE_AXES(floating_point_types))
.set_type_axes_names({"CoordsType"})
.add_int64_axis("num_pairs", {100})
.add_int64_axis("num_polygons_per_multipolygon", {1})
.add_int64_axis("num_holes_per_polygon", {0})
.add_int64_axis("num_edges_per_ring", {3})
.add_int64_axis("num_points_per_multipoint", {50, 5'00, 5'000, 50'000, 500'000})
.set_name("point_polygon_distance_benchmark_points_in_multipoint");
11 changes: 3 additions & 8 deletions cpp/include/cuspatial/detail/nvtx/ranges.hpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2020, NVIDIA CORPORATION.
* Copyright (c) 2020-2023, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand All @@ -20,17 +20,12 @@

namespace cuspatial {
/**
* @brief Tag type for libcudf's NVTX domain.
* @brief Tag type for libcuspatial's NVTX domain.
*/
struct libcuspatial_domain {
static constexpr char const* name{"libcuspatial"}; ///< Name of the libcudf domain
static constexpr char const* name{"libcuspatial"}; ///< Name of the libcuspatial domain
};

/**
* @brief Alias for an NVTX range in the libcudf domain.
*/
using thread_range = ::nvtx3::domain_thread_range<libcudf_domain>;

} // namespace cuspatial

/**
Expand Down
56 changes: 56 additions & 0 deletions cpp/include/cuspatial_test/geometry_generator.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

#pragma once

#include <cuspatial_test/random.cuh>
#include <cuspatial_test/vector_factories.cuh>

#include <cuspatial/cuda_utils.hpp>
Expand Down Expand Up @@ -251,5 +252,60 @@ auto generate_multipolygon_array(multipolygon_generator_parameter<T> params,
std::move(coordinates));
}

/**
* @brief Struct to store the parameters of the multipoint aray
*
* @tparam T Type of the coordinates
*/
template <typename T>
struct multipoint_generator_parameter {
using element_t = T;

std::size_t num_multipoints;
std::size_t num_points_per_multipoints;
vec_2d<T> lower_left;
vec_2d<T> upper_right;

CUSPATIAL_HOST_DEVICE std::size_t num_points()
{
return num_multipoints * num_points_per_multipoints;
}
};

/**
* @brief Helper to generate random multipoints within a range
*
* @tparam T The floating point type for the coordinates
* @param params Parameters to specify for the multipoints
* @param stream The CUDA stream to use for device memory operations and kernel launches
* @return a cuspatial::test::multipoint_array object
*/
template <typename T>
auto generate_multipoint_array(multipoint_generator_parameter<T> params,
rmm::cuda_stream_view stream)
{
rmm::device_uvector<vec_2d<T>> coordinates(params.num_points(), stream);
rmm::device_uvector<std::size_t> offsets(params.num_multipoints + 1, stream);

thrust::sequence(rmm::exec_policy(stream),
offsets.begin(),
offsets.end(),
std::size_t{0},
params.num_points_per_multipoints);

auto engine_x = deterministic_engine(params.num_points());
auto engine_y = deterministic_engine(2 * params.num_points());

auto x_dist = make_uniform_dist(params.lower_left.x, params.upper_right.x);
auto y_dist = make_uniform_dist(params.lower_left.y, params.upper_right.y);

auto point_gen =
point_generator(params.lower_left, params.upper_right, engine_x, engine_y, x_dist, y_dist);

thrust::tabulate(rmm::exec_policy(stream), coordinates.begin(), coordinates.end(), point_gen);

return make_multipoint_array(std::move(offsets), std::move(coordinates));
}

} // namespace test
} // namespace cuspatial
17 changes: 12 additions & 5 deletions cpp/include/cuspatial_test/random.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -154,14 +154,21 @@ struct value_generator {
template <typename T, typename Generator>
struct point_generator {
using Cart2D = cuspatial::vec_2d<T>;
value_generator<T, Generator> vgen;

point_generator(T lower_bound, T upper_bound, thrust::minstd_rand& engine, Generator gen)
: vgen(lower_bound, upper_bound, engine, gen)
value_generator<T, Generator> vgenx;
value_generator<T, Generator> vgeny;

point_generator(vec_2d<T> lower_left,
vec_2d<T> upper_right,
thrust::minstd_rand& engine_x,
thrust::minstd_rand& engine_y,
Generator gen_x,
Generator gen_y)
: vgenx(lower_left.x, upper_right.x, engine_x, gen_x),
vgeny(lower_left.y, upper_right.y, engine_y, gen_y)
{
}

__device__ Cart2D operator()(size_t n) { return {vgen(n), vgen(n)}; }
__device__ Cart2D operator()(size_t n) { return {vgenx(n), vgeny(n)}; }
};

/**
Expand Down
45 changes: 37 additions & 8 deletions cpp/include/cuspatial_test/vector_factories.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -309,11 +309,22 @@ auto make_multilinestring_array(std::initializer_list<std::size_t> geometry_inl,
template <typename GeometryArray, typename CoordinateArray>
class multipoint_array {
public:
multipoint_array(GeometryArray geometry_offsets_array, CoordinateArray coordinate_array)
using geometry_t = typename GeometryArray::value_type;
using coord_t = typename CoordinateArray::value_type;

multipoint_array(thrust::device_vector<geometry_t> geometry_offsets_array,
thrust::device_vector<coord_t> coordinate_array)
: _geometry_offsets(geometry_offsets_array), _coordinates(coordinate_array)
{
}

multipoint_array(rmm::device_uvector<geometry_t>&& geometry_offsets_array,
rmm::device_uvector<coord_t>&& coordinate_array)
: _geometry_offsets(std::move(geometry_offsets_array)),
_coordinates(std::move(coordinate_array))
{
}

/// Return the number of multipoints
auto size() { return _geometry_offsets.size() - 1; }

Expand All @@ -337,27 +348,33 @@ class multipoint_array {
* coordinates
*/
template <typename GeometryRange, typename CoordRange>
auto make_multipoints_array(GeometryRange geometry_inl, CoordRange coordinates_inl)
auto make_multipoint_array(GeometryRange geometry_inl, CoordRange coordinates_inl)
{
return multipoint_array{make_device_vector(geometry_inl), make_device_vector(coordinates_inl)};
using IndexType = typename GeometryRange::value_type;
using CoordType = typename CoordRange::value_type;
using DeviceIndexVector = thrust::device_vector<IndexType>;
using DeviceCoordVector = thrust::device_vector<CoordType>;

return multipoint_array<DeviceIndexVector, DeviceCoordVector>{
make_device_vector(geometry_inl), make_device_vector(coordinates_inl)};
}

/**
* @brief Factory method to construct multipoint array from initializer list of multipoints.
*
* Example: Construct an array of 2 multipoints, each with 2, 0, 1 points:
* using P = vec_2d<float>;
* make_multipoints_array({{P{0.0, 1.0}, P{2.0, 0.0}}, {}, {P{3.0, 4.0}}});
* make_multipoint_array({{P{0.0, 1.0}, P{2.0, 0.0}}, {}, {P{3.0, 4.0}}});
*
* Example: Construct an empty multilinestring array:
* make_multipoints_array<float>({}); // Explicit parameter required to deduce type.
* make_multipoint_array<float>({}); // Explicit parameter required to deduce type.
*
* @tparam T Type of coordinate
* @param inl List of multipoints
* @return multipoints_array object
*/
template <typename T>
auto make_multipoints_array(std::initializer_list<std::initializer_list<vec_2d<T>>> inl)
auto make_multipoint_array(std::initializer_list<std::initializer_list<vec_2d<T>>> inl)
{
std::vector<std::size_t> offsets{0};
std::transform(inl.begin(), inl.end(), std::back_inserter(offsets), [](auto multipoint) {
Expand All @@ -371,8 +388,20 @@ auto make_multipoints_array(std::initializer_list<std::initializer_list<vec_2d<T
return init;
});

return multipoint_array{rmm::device_vector<std::size_t>(offsets),
rmm::device_vector<vec_2d<T>>(coordinates)};
return multipoint_array<rmm::device_vector<std::size_t>, rmm::device_vector<vec_2d<T>>>{
rmm::device_vector<std::size_t>(offsets), rmm::device_vector<vec_2d<T>>(coordinates)};
}

/**
* @brief Factory method to construct multipoint array by moving the offsets and coordinates from
* `rmm::device_uvector`.
*/
template <typename IndexType, typename T>
auto make_multipoint_array(rmm::device_uvector<IndexType> geometry_offsets,
rmm::device_uvector<vec_2d<T>> coords)
{
return multipoint_array<rmm::device_uvector<std::size_t>, rmm::device_uvector<vec_2d<T>>>{
std::move(geometry_offsets), std::move(coords)};
}

} // namespace test
Expand Down
6 changes: 3 additions & 3 deletions cpp/tests/distance/point_distance_test.cu
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,8 @@
* limitations under the License.
*/

#include <cuspatial_test/vector_equality.hpp>

#include <cuspatial_test/random.cuh>
#include <cuspatial_test/vector_equality.hpp>

#include <cuspatial/distance.cuh>
#include <cuspatial/error.hpp>
Expand Down Expand Up @@ -59,7 +58,8 @@ struct PairwisePointDistanceTest : public ::testing::Test {
{
auto engine = cuspatial::test::deterministic_engine(0);
auto uniform = cuspatial::test::make_normal_dist<T>(0.0, 1.0);
auto pgen = cuspatial::test::point_generator(T{0.0}, T{1.0}, engine, uniform);
auto pgen = cuspatial::test::point_generator(
vec_2d<T>{0.0, 0.0}, vec_2d<T>{1.0, 1.0}, engine, engine, uniform, uniform);
rmm::device_vector<vec_2d<T>> points(num_points);
auto counting_iter = thrust::make_counting_iterator(seed);
thrust::transform(
Expand Down
2 changes: 1 addition & 1 deletion cpp/tests/distance/point_polygon_distance_test.cu
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ struct PairwisePointPolygonDistanceTest : public ::testing::Test {
std::vector<vec_2d<T>> const& multipolygon_coordinates,
std::initializer_list<T> expected)
{
auto d_multipoints = make_multipoints_array(multipoints);
auto d_multipoints = make_multipoint_array(multipoints);
auto d_multipolygons = make_multipolygon_array(
range{multipolygon_geometry_offsets.begin(), multipolygon_geometry_offsets.end()},
range{multipolygon_part_offsets.begin(), multipolygon_part_offsets.end()},
Expand Down
4 changes: 2 additions & 2 deletions cpp/tests/equality/pairwise_multipoint_equals_count_test.cu
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,8 @@ struct PairwiseMultipointEqualsCountTest : public BaseFixture {
std::initializer_list<std::initializer_list<vec_2d<T>>> rhs_coordinates,
std::initializer_list<uint32_t> expected)
{
auto larray = make_multipoints_array(lhs_coordinates);
auto rarray = make_multipoints_array(rhs_coordinates);
auto larray = make_multipoint_array(lhs_coordinates);
auto rarray = make_multipoint_array(rhs_coordinates);

auto lhs = larray.range();
auto rhs = rarray.range();
Expand Down
Loading

0 comments on commit 7f42eeb

Please sign in to comment.