Skip to content

Commit

Permalink
fix format
Browse files Browse the repository at this point in the history
  • Loading branch information
jacobtrombetta committed Sep 27, 2024
1 parent b373a39 commit 7af6d8f
Show file tree
Hide file tree
Showing 4 changed files with 250 additions and 248 deletions.
105 changes: 53 additions & 52 deletions sxt/algorithm/block/runlength_count.h
Original file line number Diff line number Diff line change
Expand Up @@ -34,67 +34,68 @@
* https://docs.nvidia.com/cuda/cuda-toolkit-release-notes/index.html
*/
#include <__config>

#define _VSTD std::_LIBCPP_ABI_NAMESPACE

#include "cub/cub.cuh"
#include "sxt/base/macro/cuda_callable.h"

namespace sxt::algbk {
//--------------------------------------------------------------------------------------------------
// runlength_count
//--------------------------------------------------------------------------------------------------
/**
* This is adapted from block_histogram_sort in CUB. See
* https://github.com/NVIDIA/cccl/blob/a51b1f8c75f8e577eeccc74b45f1ff16a2727265/cub/cub/block/specializations/block_histogram_sort.cuh
*/
template <class T, class CounterT, unsigned NumThreads, unsigned NumBins> class runlength_count {
using Discontinuity = cub::BlockDiscontinuity<T, NumThreads>;
namespace sxt::algbk {
//--------------------------------------------------------------------------------------------------
// runlength_count
//--------------------------------------------------------------------------------------------------
/**
* This is adapted from block_histogram_sort in CUB. See
* https://github.com/NVIDIA/cccl/blob/a51b1f8c75f8e577eeccc74b45f1ff16a2727265/cub/cub/block/specializations/block_histogram_sort.cuh
*/
template <class T, class CounterT, unsigned NumThreads, unsigned NumBins> class runlength_count {
using Discontinuity = cub::BlockDiscontinuity<T, NumThreads>;

public:
struct temp_storage {
typename Discontinuity::TempStorage discontinuity;
CounterT run_begin[NumBins];
CounterT run_end[NumBins];
};
public:
struct temp_storage {
typename Discontinuity::TempStorage discontinuity;
CounterT run_begin[NumBins];
CounterT run_end[NumBins];
};

CUDA_CALLABLE explicit runlength_count(temp_storage& storage) noexcept
: storage_{storage}, discontinuity_{storage.discontinuity} {}
CUDA_CALLABLE explicit runlength_count(temp_storage& storage) noexcept
: storage_{storage}, discontinuity_{storage.discontinuity} {}

/**
* If items holds sorted items across threads in a block, count and return a
* pointer to a table of the items' run lengths.
*/
template <unsigned ItemsPerThread>
CUDA_CALLABLE CounterT* count(T (&items)[ItemsPerThread]) noexcept {
auto thread_id = threadIdx.x;
for (unsigned i = thread_id; i < NumBins; i += NumThreads) {
storage_.run_begin[i] = NumThreads * ItemsPerThread;
storage_.run_end[i] = NumThreads * ItemsPerThread;
}
int flags[ItemsPerThread];
auto flag_op = [&storage = storage_](T a, T b, int b_index) noexcept {
if (a != b) {
storage.run_begin[b] = static_cast<CounterT>(b_index);
storage.run_end[a] = static_cast<CounterT>(b_index);
return true;
} else {
return false;
}
};
__syncthreads();
discontinuity_.FlagHeads(flags, items, flag_op);
if (thread_id == 0) {
storage_.run_begin[items[0]] = 0;
}
__syncthreads();
for (unsigned i = thread_id; i < NumBins; i += NumThreads) {
storage_.run_end[i] -= storage_.run_begin[i];
/**
* If items holds sorted items across threads in a block, count and return a
* pointer to a table of the items' run lengths.
*/
template <unsigned ItemsPerThread>
CUDA_CALLABLE CounterT* count(T (&items)[ItemsPerThread]) noexcept {
auto thread_id = threadIdx.x;
for (unsigned i = thread_id; i < NumBins; i += NumThreads) {
storage_.run_begin[i] = NumThreads * ItemsPerThread;
storage_.run_end[i] = NumThreads * ItemsPerThread;
}
int flags[ItemsPerThread];
auto flag_op = [&storage = storage_](T a, T b, int b_index) noexcept {
if (a != b) {
storage.run_begin[b] = static_cast<CounterT>(b_index);
storage.run_end[a] = static_cast<CounterT>(b_index);
return true;
} else {
return false;
}
return storage_.run_end;
};
__syncthreads();
discontinuity_.FlagHeads(flags, items, flag_op);
if (thread_id == 0) {
storage_.run_begin[items[0]] = 0;
}
__syncthreads();
for (unsigned i = thread_id; i < NumBins; i += NumThreads) {
storage_.run_end[i] -= storage_.run_begin[i];
}
return storage_.run_end;
}

private:
temp_storage& storage_;
Discontinuity discontinuity_;
};
private:
temp_storage& storage_;
Discontinuity discontinuity_;
};
} // namespace sxt::algbk
164 changes: 82 additions & 82 deletions sxt/multiexp/base/scalar_array.cc
Original file line number Diff line number Diff line change
Expand Up @@ -50,100 +50,100 @@
#include "sxt/memory/management/managed_array.h"
#include "sxt/memory/resource/async_device_resource.h"

namespace sxt::mtxb {
//--------------------------------------------------------------------------------------------------
// scalar_blob
//--------------------------------------------------------------------------------------------------
namespace {
template <unsigned NumBytes> struct scalar_blob {
uint8_t data[NumBytes];
};
} // namespace
namespace sxt::mtxb {
//--------------------------------------------------------------------------------------------------
// scalar_blob
//--------------------------------------------------------------------------------------------------
namespace {
template <unsigned NumBytes> struct scalar_blob {
uint8_t data[NumBytes];
};
} // namespace

//--------------------------------------------------------------------------------------------------
// transpose_kernel
//--------------------------------------------------------------------------------------------------
template <unsigned NumBytes>
static __global__ void transpose_kernel(uint8_t* __restrict__ dst,
const scalar_blob<NumBytes>* __restrict__ src,
unsigned n) noexcept {
using Scalar = scalar_blob<NumBytes>;
//--------------------------------------------------------------------------------------------------
// transpose_kernel
//--------------------------------------------------------------------------------------------------
template <unsigned NumBytes>
static __global__ void transpose_kernel(uint8_t* __restrict__ dst,
const scalar_blob<NumBytes>* __restrict__ src,
unsigned n) noexcept {
using Scalar = scalar_blob<NumBytes>;

auto byte_index = threadIdx.x;
auto tile_index = blockIdx.x;
auto output_index = blockIdx.y;
auto num_tiles = gridDim.x;
auto n_per_tile = basn::divide_up(basn::divide_up(n, NumBytes), num_tiles) * NumBytes;
auto byte_index = threadIdx.x;
auto tile_index = blockIdx.x;
auto output_index = blockIdx.y;
auto num_tiles = gridDim.x;
auto n_per_tile = basn::divide_up(basn::divide_up(n, NumBytes), num_tiles) * NumBytes;

auto first = tile_index * n_per_tile;
auto m = min(n_per_tile, n - first);
auto first = tile_index * n_per_tile;
auto m = min(n_per_tile, n - first);

// adjust pointers
src += first;
src += output_index * n;
dst += byte_index * n + first;
dst += output_index * NumBytes * n;
// adjust pointers
src += first;
src += output_index * n;
dst += byte_index * n + first;
dst += output_index * NumBytes * n;

// set up algorithm
using BlockExchange = cub::BlockExchange<uint8_t, NumBytes, NumBytes>;
__shared__ typename BlockExchange::TempStorage temp_storage;
// set up algorithm
using BlockExchange = cub::BlockExchange<uint8_t, NumBytes, NumBytes>;
__shared__ typename BlockExchange::TempStorage temp_storage;

// transpose
Scalar s;
unsigned out_first = 0;
for (unsigned i = byte_index; i < n_per_tile; i += NumBytes) {
if (i < m) {
s = src[i];
}
BlockExchange(temp_storage).StripedToBlocked(s.data);
__syncthreads();
for (unsigned j = 0; j < NumBytes; ++j) {
auto out_index = out_first + j;
if (out_index < m) {
dst[out_index] = s.data[j];
}
// transpose
Scalar s;
unsigned out_first = 0;
for (unsigned i = byte_index; i < n_per_tile; i += NumBytes) {
if (i < m) {
s = src[i];
}
BlockExchange(temp_storage).StripedToBlocked(s.data);
__syncthreads();
for (unsigned j = 0; j < NumBytes; ++j) {
auto out_index = out_first + j;
if (out_index < m) {
dst[out_index] = s.data[j];
}
out_first += NumBytes;
__syncthreads();
}
out_first += NumBytes;
__syncthreads();
}
}

//--------------------------------------------------------------------------------------------------
// transpose_scalars_to_device
//--------------------------------------------------------------------------------------------------
xena::future<> transpose_scalars_to_device(basct::span<uint8_t> array,
basct::cspan<const uint8_t*> scalars,
unsigned element_num_bytes, unsigned n) noexcept {
auto num_outputs = static_cast<unsigned>(scalars.size());
if (n == 0 || num_outputs == 0) {
co_return;
}
SXT_DEBUG_ASSERT(
// clang-format off
//--------------------------------------------------------------------------------------------------
// transpose_scalars_to_device
//--------------------------------------------------------------------------------------------------
xena::future<> transpose_scalars_to_device(basct::span<uint8_t> array,
basct::cspan<const uint8_t*> scalars,
unsigned element_num_bytes, unsigned n) noexcept {
auto num_outputs = static_cast<unsigned>(scalars.size());
if (n == 0 || num_outputs == 0) {
co_return;
}
SXT_DEBUG_ASSERT(
// clang-format off
array.size() == num_outputs * element_num_bytes * n &&
basdv::is_active_device_pointer(array.data()) &&
basdv::is_host_pointer(scalars[0])
// clang-format on
);
basdv::stream stream;
memr::async_device_resource resource{stream};
memmg::managed_array<uint8_t> array_p{array.size(), &resource};
auto num_bytes_per_output = element_num_bytes * n;
for (size_t output_index = 0; output_index < num_outputs; ++output_index) {
basdv::async_copy_host_to_device(
basct::subspan(array_p, output_index * num_bytes_per_output, num_bytes_per_output),
basct::cspan<uint8_t>{scalars[output_index], num_bytes_per_output}, stream);
}
auto num_tiles = std::min(basn::divide_up(n, num_outputs * element_num_bytes), 64u);
auto num_bytes_log2 = basn::ceil_log2(element_num_bytes);
basn::constexpr_switch<6>(
num_bytes_log2,
[&]<unsigned LogNumBytes>(std::integral_constant<unsigned, LogNumBytes>) noexcept {
constexpr auto NumBytes = 1u << LogNumBytes;
SXT_DEBUG_ASSERT(NumBytes == element_num_bytes);
transpose_kernel<NumBytes><<<dim3(num_tiles, num_outputs, 1), NumBytes, 0, stream>>>(
array.data(), reinterpret_cast<scalar_blob<NumBytes>*>(array_p.data()), n);
});
co_await xendv::await_stream(std::move(stream));
// clang-format on
);
basdv::stream stream;
memr::async_device_resource resource{stream};
memmg::managed_array<uint8_t> array_p{array.size(), &resource};
auto num_bytes_per_output = element_num_bytes * n;
for (size_t output_index = 0; output_index < num_outputs; ++output_index) {
basdv::async_copy_host_to_device(
basct::subspan(array_p, output_index * num_bytes_per_output, num_bytes_per_output),
basct::cspan<uint8_t>{scalars[output_index], num_bytes_per_output}, stream);
}
auto num_tiles = std::min(basn::divide_up(n, num_outputs * element_num_bytes), 64u);
auto num_bytes_log2 = basn::ceil_log2(element_num_bytes);
basn::constexpr_switch<6>(
num_bytes_log2,
[&]<unsigned LogNumBytes>(std::integral_constant<unsigned, LogNumBytes>) noexcept {
constexpr auto NumBytes = 1u << LogNumBytes;
SXT_DEBUG_ASSERT(NumBytes == element_num_bytes);
transpose_kernel<NumBytes><<<dim3(num_tiles, num_outputs, 1), NumBytes, 0, stream>>>(
array.data(), reinterpret_cast<scalar_blob<NumBytes>*>(array_p.data()), n);
});
co_await xendv::await_stream(std::move(stream));
}
} // namespace sxt::mtxb
Loading

0 comments on commit 7af6d8f

Please sign in to comment.