Skip to content

add counting sort example #2

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
88 changes: 88 additions & 0 deletions examples/hlsl/counting_sort.hlsl
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
// The entry point and target profile are needed to compile this example:
// -T ps_6_6 -E PSMain

#include "nbl/builtin/hlsl/sort/counting.hlsl"
#include "nbl/builtin/hlsl/bda/bda_accessor.hlsl"

#define BucketCount 27
#define WorkgroupSize 27

struct CountingPushData
{
uint64_t inputKeyAddress;
uint64_t inputValueAddress;
uint64_t histogramAddress;
uint64_t outputKeyAddress;
uint64_t outputValueAddress;
uint32_t dataElementCount;
uint32_t elementsPerWT;
uint32_t minimum;
uint32_t maximum;
};

using namespace nbl::hlsl;

using Ptr = bda::__ptr<uint32_t>;
using PtrAccessor = BdaAccessor<uint32_t>;

groupshared uint32_t sdata[BucketCount];

struct SharedAccessor
{
void get(const uint32_t index, NBL_REF_ARG(uint32_t) value)
{
value = sdata[index];
}

void set(const uint32_t index, const uint32_t value)
{
sdata[index] = value;
}

uint32_t atomicAdd(const uint32_t index, const uint32_t value)
{
return glsl::atomicAdd(sdata[index], value);
}

void workgroupExecutionAndMemoryBarrier()
{
glsl::barrier();
}
};

uint32_t3 glsl::gl_WorkGroupSize()
{
return uint32_t3(WorkgroupSize, 1, 1);
}

[[vk::push_constant]] CountingPushData pushData;

using DoublePtrAccessor = DoubleBdaAccessor<uint32_t>;

[[vk::push_constant]] CountingPushData pushData;

[numthreads(WorkgroupSize,1,1)]
void main(uint32_t3 ID : SV_GroupThreadID, uint32_t3 GroupID : SV_GroupID)
{
sort::CountingParameters < uint32_t > params;
params.dataElementCount = pushData.dataElementCount;
params.elementsPerWT = pushData.elementsPerWT;
params.minimum = pushData.minimum;
params.maximum = pushData.maximum;

using Counter = sort::counting<WorkgroupSize, BucketCount, PtrAccessor, PtrAccessor, PtrAccessor, SharedAccessor, PtrAccessor::type_t>;
Counter counter = Counter::create(glsl::gl_WorkGroupID().x);

const Ptr input_ptr = Ptr::create(pushData.inputKeyAddress);
const Ptr histogram_ptr = Ptr::create(pushData.histogramAddress);

PtrAccessor input_accessor = PtrAccessor::create(input_ptr);
PtrAccessor histogram_accessor = PtrAccessor::create(histogram_ptr);
SharedAccessor shared_accessor;
counter.histogram(
input_accessor,
histogram_accessor,
shared_accessor,
params
);
}