Skip to content

Commit

Permalink
small cleanups.
Browse files Browse the repository at this point in the history
  • Loading branch information
trivialfis committed Dec 16, 2024
1 parent d0d6880 commit 8253481
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 11 deletions.
11 changes: 6 additions & 5 deletions src/encoder/ordinal.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ struct SegmentedSearchSortedStrOp {
return l_str < r_str;
});
if (ret_it == it + f_sorted_idx.size()) {
return -1; // not found
return detail::NotFound();
}
return *ret_it;
}
Expand Down Expand Up @@ -242,7 +242,7 @@ void Recode(ExecPolicy const& policy, DeviceColumnsView orig_enc,
orig_enc.feature_segments[f_idx + 1] - orig_enc.feature_segments[f_idx]);

std::int32_t idx = -1;
if (searched_idx != -1) {
if (searched_idx != detail::NotFound()) {
idx = f_sorted_idx[searched_idx];
}

Expand All @@ -252,9 +252,10 @@ void Recode(ExecPolicy const& policy, DeviceColumnsView orig_enc,
f_mapping[i - f_beg] = idx;
});

auto err_it = thrust::find_if(
exec, dh::tcbegin(mapping), dh::tcend(mapping),
cuda::proclaim_return_type<bool>([=] __device__(std::int32_t v) { return v == -1; }));
auto err_it = thrust::find_if(exec, dh::tcbegin(mapping), dh::tcend(mapping),
cuda::proclaim_return_type<bool>([=] __device__(std::int32_t v) {
return v == detail::NotFound();
}));

if (err_it != dh::tcend(mapping)) {
// Report missing cat.
Expand Down
13 changes: 7 additions & 6 deletions src/encoder/ordinal.h
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
* strings as constant). Originally, the encoding for test data set is [0, 1, 2] for [c,
* a, b], now we have a mapping {0 -> 0, 1 -> 3, 2 -> 1} for re-coding the data.
*
* This module exposes 2 functions and an exection policy:
* This module exposes 2 functions and an execution policy:
* - @ref Recode
* - @ref SortNames
* Each of them has a device counterpart.
Expand Down Expand Up @@ -99,7 +99,7 @@ using DeviceCatIndexView = cuda_impl::TupToVarT<CatIndexViewTypes>;
* Accepted policies:
*
* - A class with a `ThrustPolicy` method that returns a thrust execution policy, along with a
* `ThrustAllocator` template type.
* `ThrustAllocator` template type. This is only used for the GPU implementation.
*
* - An error handling policy that exposes a single `Error` method, which takes a single
* string parameter for error message.
Expand All @@ -109,6 +109,7 @@ struct Policy : public Derived... {};

namespace detail {
constexpr std::int32_t SearchKey() { return -1; }
constexpr std::int32_t NotFound() { return -1; }

template <typename Variant>
struct ColumnsViewImpl {
Expand Down Expand Up @@ -235,7 +236,7 @@ void ArgSort(InIt in_first, InIt in_last, OutIt out_first, Comp comp = std::less
return l_str < r_str;
});
if (ret_it == it + haystack.size()) {
return -1;
return detail::NotFound();
}
return *ret_it;
}
Expand All @@ -251,7 +252,7 @@ SearchSorted(Span<T const> haystack, Span<std::int32_t const> ref_sorted_idx, T
return l_value < r_value;
});
if (ret_it == it + haystack.size()) {
return -1;
return detail::NotFound();
}
return *ret_it;
}
Expand Down Expand Up @@ -352,7 +353,7 @@ void Recode(ExecPolicy const &policy, HostColumnsView orig_enc, Span<std::int32_
searched_idx[j - 1] = cpu_impl::SearchSorted(
std::get<CatStrArrayView>(orig_enc.columns[f_idx]),
ref_sorted_idx, needle);
if (searched_idx[j - 1] == -1) {
if (searched_idx[j - 1] == detail::NotFound()) {
std::stringstream ss;
for (auto c : needle) {
ss.put(c);
Expand All @@ -368,7 +369,7 @@ void Recode(ExecPolicy const &policy, HostColumnsView orig_enc, Span<std::int32_
searched_idx[j] = cpu_impl::SearchSorted(
std::get<Span<std::add_const_t<T>>>(orig_enc.columns[f_idx]),
ref_sorted_idx, needle);
if (searched_idx[j] == -1) {
if (searched_idx[j] == detail::NotFound()) {
std::stringstream ss;
ss << needle;
detail::ReportMissing(policy, ss.str(), f_idx);
Expand Down

0 comments on commit 8253481

Please sign in to comment.