Commit

implement omp on demand tls

Signed-off-by: Dan Hoeflinger <[email protected]>
danhoeflinger committed Dec 30, 2024
1 parent 6b5019d commit 95d483f
Showing 4 changed files with 56 additions and 22 deletions.
4 changes: 2 additions & 2 deletions include/oneapi/dpl/pstl/algorithm_impl.h
@@ -4348,13 +4348,13 @@ __pattern_histogram(__parallel_tag<_IsVector>, _ExecutionPolicy&& __exec, _Rando
         [__histogram_first, &__tls](auto __global_histogram_first,
                                     auto __global_histogram_last) {
             _DiffType __local_n = __global_histogram_last - __global_histogram_first;
-            std::size_t __num_temporary_copies = __tls.size();
+            std::uint32_t __num_temporary_copies = __tls.size();
             _DiffType __range_begin_id = __global_histogram_first - __histogram_first;
             //initialize output global histogram with first local histogram via assign
             __internal::__brick_walk2_n(__tls.get_with_id(0).begin() + __range_begin_id, __local_n,
                                         __global_histogram_first, oneapi::dpl::__internal::__pstl_assign(),
                                         _IsVector{});
-            for (std::size_t __i = 1; __i < __num_temporary_copies; ++__i)
+            for (std::uint32_t __i = 1; __i < __num_temporary_copies; ++__i)
             {
                 //accumulate into output global histogram with other local histogram via += operator
                 __internal::__brick_walk2_n(
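For context, the loop in this hunk folds the per-thread partial histograms held by __tls back into the caller's output range: copy 0 seeds the global histogram via assignment, and the remaining copies are accumulated with +=. The switch from std::size_t to std::uint32_t simply matches the new return type of size() in the backends changed below. A minimal standalone sketch of the same merge pattern, using plain std::vector copies in place of the library's TLS container and brick helpers (all names here are illustrative, not oneDPL APIs):

#include <algorithm>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical stand-in for the per-thread copies exposed by the TLS container.
using local_histogram = std::vector<std::uint64_t>;

// Merge the populated per-thread histograms into one global histogram:
// copy 0 initializes the output (the "assign" brick), the rest are added in.
std::vector<std::uint64_t>
merge_local_histograms(const std::vector<local_histogram>& locals, std::size_t bins)
{
    std::vector<std::uint64_t> global(bins, 0);
    if (locals.empty())
        return global;
    std::copy(locals[0].begin(), locals[0].end(), global.begin());
    for (std::uint32_t i = 1; i < static_cast<std::uint32_t>(locals.size()); ++i)
        for (std::size_t bin = 0; bin < bins; ++bin)
            global[bin] += locals[i][bin];
    return global;
}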
64 changes: 49 additions & 15 deletions include/oneapi/dpl/pstl/omp/util.h
@@ -25,6 +25,7 @@
 #include <vector>
 #include <type_traits>
 #include <omp.h>
+#include <atomic>

 #include "../parallel_backend_utils.h"
 #include "../unseq_backend_simd.h"
@@ -153,45 +154,78 @@ __process_chunk(const __chunk_metrics& __metrics, _Iterator __base, _Index __chu
     __f(__first, __last);
 }

+template<typename _StorageType>
+class __construct_by_args_base {
+  public:
+    virtual ~__construct_by_args_base() { }
+    virtual std::unique_ptr<_StorageType> construct() = 0;
+};
+
+template<typename _StorageType, typename... _P>
+class __construct_by_args: public __construct_by_args_base<_StorageType>{
+  public:
+    std::unique_ptr<_StorageType> construct() {
+        return std::move(std::apply([](auto... __arg_pack) { return std::make_unique<_StorageType>(__arg_pack...);}, pack));
+    }
+    __construct_by_args( _P&& ... args ) : pack(std::forward<_P>(args)...) {}
+  private:
+    std::tuple<_P...> pack;
+};
+
 template <typename _StorageType>
 struct __thread_enumerable_storage
 {
     template <typename... Args>
-    __thread_enumerable_storage(Args&&... args)
+    __thread_enumerable_storage(Args&&... args) : __num_elements(0)
     {
+        __construct_helper = std::make_unique<__construct_by_args<_StorageType, Args...>>(std::forward<Args>(args)...);
         _PSTL_PRAGMA(omp parallel)
-        _PSTL_PRAGMA(omp single nowait)
+        _PSTL_PRAGMA(omp single)
         {
-            __num_threads = omp_get_num_threads();
-            __thread_specific_storage.resize(__num_threads);
-            _PSTL_PRAGMA(omp taskloop shared(__thread_specific_storage))
-            for (std::size_t __tid = 0; __tid < __num_threads; ++__tid)
-            {
-                __thread_specific_storage[__tid] = std::make_unique<_StorageType>(std::forward<Args>(args)...);
-            }
+            __thread_specific_storage.resize(omp_get_num_threads());
         }
     }

-    std::size_t
+    std::uint32_t
     size() const
     {
-        return __num_threads;
+        return __num_elements.load();
     }

     _StorageType&
-    get_with_id(std::size_t __i)
+    get_with_id(std::uint32_t __i)
     {
-        return *__thread_specific_storage[__i];
+        if (__i < size())
+        {
+            std::uint32_t __count = 0;
+            std::uint32_t __j = 0;
+            for (; __j < __thread_specific_storage.size() && __count <= __i; ++__j)
+            {
+                if (__thread_specific_storage[__j])
+                {
+                    __count++;
+                }
+            }
+            // Need to back up one once we have found a valid element
+            return *__thread_specific_storage[__j-1];
+        }
     }

     _StorageType&
     get()
     {
-        return get_with_id(omp_get_thread_num());
+        std::uint32_t __i = omp_get_thread_num();
+        if (!__thread_specific_storage[__i])
+        {
+            __thread_specific_storage[__i] = __construct_helper->construct();
+            __num_elements.fetch_add(1);
+        }
+        return *__thread_specific_storage[__i];
     }

     std::vector<std::unique_ptr<_StorageType>> __thread_specific_storage;
-    std::size_t __num_threads;
+    std::atomic<std::uint32_t> __num_elements;
+    std::unique_ptr<__construct_by_args_base<_StorageType>> __construct_helper;
 };

 } // namespace __omp_backend
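The net effect of this hunk is on-demand thread-local storage: the constructor now only captures the constructor arguments (via __construct_by_args) and sizes the slot vector inside a parallel/single region, each thread's first call to get() builds its own _StorageType, size() reports how many copies were actually created, and get_with_id() walks only the populated slots. A self-contained sketch of that pattern, written as an illustration of the idea rather than the oneDPL class itself (all names below are hypothetical):

#include <atomic>
#include <cstdint>
#include <memory>
#include <stdexcept>
#include <utility>
#include <vector>
#include <omp.h>

// Sketch of the on-demand TLS idea: slots are sized up front, but a thread's
// storage is only constructed the first time that thread asks for it.
template <typename T>
struct on_demand_tls
{
    explicit on_demand_tls(T init) : init_(std::move(init))
    {
#pragma omp parallel
#pragma omp single
        slots_.resize(omp_get_num_threads());
    }

    T&
    get()
    {
        auto& slot = slots_[omp_get_thread_num()];
        if (!slot)
        {
            slot = std::make_unique<T>(init_); // lazy, per-thread construction
            count_.fetch_add(1);
        }
        return *slot;
    }

    std::uint32_t
    size() const
    {
        return count_.load(); // copies actually created, not omp_get_num_threads()
    }

    // Return the i-th populated slot, skipping threads that never called get().
    T&
    get_with_id(std::uint32_t i)
    {
        std::uint32_t seen = 0;
        for (auto& slot : slots_)
        {
            if (slot && seen++ == i)
                return *slot;
        }
        throw std::out_of_range("no populated copy with that id");
    }

    std::vector<std::unique_ptr<T>> slots_;
    std::atomic<std::uint32_t> count_{0};
    T init_;
};

Threads the runtime never schedules onto the work leave their slot empty, which is why the reduction in __pattern_histogram iterates over size() created copies rather than one copy per OpenMP thread.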
6 changes: 3 additions & 3 deletions include/oneapi/dpl/pstl/parallel_backend_serial.h
@@ -50,10 +50,10 @@ struct __thread_enumerable_storage
     {
     }

-    std::size_t
+    std::uint32_t
     size() const
     {
-        return std::size_t{1};
+        return std::uint32_t{1};
     }

     _StorageType&
@@ -63,7 +63,7 @@ struct __thread_enumerable_storage
     }

     _StorageType&
-    get_with_id(std::size_t __i)
+    get_with_id(std::uint32_t __i)
     {
         return get();
     }
4 changes: 2 additions & 2 deletions include/oneapi/dpl/pstl/parallel_backend_tbb.h
@@ -1316,7 +1316,7 @@ struct __thread_enumerable_storage
     {
     }

-    std::size_t
+    std::uint32_t
     size() const
     {
         return __thread_specific_storage.size();
@@ -1329,7 +1329,7 @@ struct __thread_enumerable_storage
     }

     _StorageType&
-    get_with_id(std::size_t __i)
+    get_with_id(std::uint32_t __i)
     {
         return __thread_specific_storage.begin()[__i];
     }
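The serial and TBB backends only need the std::uint32_t signature updates to keep the interface consistent; for TBB in particular, tbb::enumerable_thread_specific already constructs an element the first time a thread calls local(), size() already reports only created elements, and iteration visits only those elements. A small standalone example of that behavior (not oneDPL code; assumes oneTBB is available):

#include <cstdint>
#include <tbb/enumerable_thread_specific.h>
#include <tbb/parallel_for.h>

int main()
{
    // One counter per thread that actually participates; nothing is built up front.
    tbb::enumerable_thread_specific<std::uint64_t> counters(std::uint64_t{0});

    tbb::parallel_for(0, 1000, [&](int) { ++counters.local(); }); // local() constructs lazily

    std::uint64_t total = 0;
    for (const auto& c : counters) // visits only the copies that were created
        total += c;

    // size() likewise counts only created copies, matching the std::uint32_t size() above.
    return (total == 1000 && counters.size() >= 1) ? 0 : 1;
}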
