Skip to content

Commit

Permalink
fallback buffer
Browse files Browse the repository at this point in the history
  • Loading branch information
swfly committed Nov 27, 2024
1 parent b74da1e commit b9a9599
Show file tree
Hide file tree
Showing 6 changed files with 46 additions and 42 deletions.
1 change: 0 additions & 1 deletion src/backends/fallback/fallback_accel.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -211,5 +211,4 @@ void accel_transform_wrapper(void* accel, unsigned id, void* buffer)
luisa::compute::fallback::detail::fill_transform(
reinterpret_cast<luisa::compute::fallback::FallbackAccel *>(accel),
id, reinterpret_cast<luisa::float4x4 *>(buffer));
auto g = reinterpret_cast<luisa::float4x4 *>(buffer);
}
15 changes: 13 additions & 2 deletions src/backends/fallback/fallback_buffer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,19 @@
namespace luisa::compute::fallback {

FallbackBufferView FallbackBuffer::view(size_t offset) noexcept {
LUISA_ASSERT(offset <= data.size(), "Buffer view out of range.");
return {static_cast<void *>(data.data() + offset), data.size() - offset};
LUISA_ASSERT(offset <= size, "Buffer view out of range.");
return {static_cast<void *>(data + offset * elementStride), size-offset};
}

FallbackBuffer::FallbackBuffer(size_t size, unsigned elementStride):elementStride(elementStride), size(size)
{
data = luisa::allocate_with_allocator<std::byte>(size * elementStride);

}

FallbackBuffer::~FallbackBuffer()
{
luisa::deallocate_with_allocator(data);
}

}// namespace luisa::compute::fallback
10 changes: 6 additions & 4 deletions src/backends/fallback/fallback_buffer.h
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,15 @@ struct alignas(16) FallbackBufferView {

class FallbackBuffer {
public:
// FIXME: size
void *addr() { return data.data(); }
explicit FallbackBuffer(size_t size, unsigned elementStride);
void *addr() { return data; }
[[nodiscard]] FallbackBufferView view(size_t offset) noexcept;

~FallbackBuffer();
private:

std::vector<std::byte> data{};
unsigned elementStride;
unsigned size;
std::byte* data{};
};

}// namespace luisa::compute::fallback
7 changes: 4 additions & 3 deletions src/backends/fallback/fallback_device.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
#include "fallback_accel.h"
#include "fallback_bindless_array.h"
#include "fallback_shader.h"
#include "fallback_buffer.h"

//#include "llvm_event.h"
//#include "llvm_shader.h"
Expand Down Expand Up @@ -70,7 +71,7 @@ void *FallbackDevice::native_handle() const noexcept {
}

void FallbackDevice::destroy_buffer(uint64_t handle) noexcept {
luisa::deallocate_with_allocator(reinterpret_cast<std::byte *>(handle));
luisa::deallocate_with_allocator(reinterpret_cast<FallbackBuffer *>(handle));
}

void FallbackDevice::destroy_texture(uint64_t handle) noexcept {
Expand Down Expand Up @@ -136,8 +137,8 @@ BufferCreationInfo FallbackDevice::create_buffer(const Type *element, size_t ele
info.element_stride = element->size();
}
info.total_size_bytes = info.element_stride * elem_count;
info.handle = reinterpret_cast<uint64_t>(
luisa::allocate_with_allocator<std::byte>(info.total_size_bytes));
auto buffer = luisa::new_with_allocator<FallbackBuffer>(elem_count, info.element_stride);
info.handle = reinterpret_cast<uint64_t>(buffer);
info.native_handle = reinterpret_cast<void *>(info.handle);
return info;
}
Expand Down
44 changes: 17 additions & 27 deletions src/backends/fallback/fallback_shader.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -211,21 +211,11 @@ void compute::fallback::FallbackShader::dispatch(ThreadPool &pool, const compute
using Tag = ShaderDispatchCommand::Argument::Tag;
switch (arg.tag) {
case Tag::BUFFER: {
//What is indirect?
// if (reinterpret_cast<const CUDABufferBase *>(arg.buffer.handle)->is_indirect())
// {
// auto buffer = reinterpret_cast<const CUDAIndirectDispatchBuffer *>(arg.buffer.handle);
// auto binding = buffer->binding(arg.buffer.offset, arg.buffer.size);
// auto ptr = allocate_argument(sizeof(binding));
// std::memcpy(ptr, &binding, sizeof(binding));
// }
// else
{
auto buffer = reinterpret_cast<FallbackBuffer *>(arg.buffer.handle);
auto buffer_view = buffer->view(arg.buffer.offset);
//auto binding = buffer->binding(arg.buffer.offset, arg.buffer.size);
auto ptr = allocate_argument(sizeof(buffer_view));
std::memcpy(ptr, &buffer, sizeof(buffer_view));
std::memcpy(ptr, &buffer_view, sizeof(buffer_view));
}
break;
}
Expand Down Expand Up @@ -282,23 +272,23 @@ void compute::fallback::FallbackShader::dispatch(ThreadPool &pool, const compute

auto data = argument_buffer.data();

// for (int i = 0; i < dispatch_counts.x; ++i) {
// for (int j = 0; j < dispatch_counts.y; ++j) {
// for (int k = 0; k < dispatch_counts.z; ++k) {
// auto c = config;
// c.block_id = make_uint3(i, j, k);
// (*_kernel_entry)(data, &c);
// }
// }
// }
for (int i = 0; i < dispatch_counts.x; ++i) {
for (int j = 0; j < dispatch_counts.y; ++j) {
for (int k = 0; k < dispatch_counts.z; ++k) {
auto c = config;
c.block_id = make_uint3(i, j, k);
(*_kernel_entry)(data, &c);
}
}
}

pool.parallel(dispatch_counts.x, dispatch_counts.y, dispatch_counts.z,
[this, config, data](auto bx, auto by, auto bz) noexcept {
auto c = config;
c.block_id = make_uint3(bx, by, bz);
(*_kernel_entry)(data, &c);
});
pool.barrier();
// pool.parallel(dispatch_counts.x, dispatch_counts.y, dispatch_counts.z,
// [this, config, data](auto bx, auto by, auto bz) noexcept {
// auto c = config;
// c.block_id = make_uint3(bx, by, bz);
// (*_kernel_entry)(data, &c);
// });
// pool.barrier();
}
void compute::fallback::FallbackShader::build_bound_arguments(compute::Function kernel) {
_bound_arguments.reserve(kernel.bound_arguments().size());
Expand Down
11 changes: 6 additions & 5 deletions src/backends/fallback/fallback_stream.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
#include "fallback_mesh.h"
#include "fallback_texture.h"
#include "fallback_shader.h"
#include "fallback_buffer.h"

namespace luisa::compute::fallback {

Expand Down Expand Up @@ -45,24 +46,24 @@ void FallbackStream::visit(const BufferUploadCommand *command) noexcept {
std::memcpy(temp_buffer->data(), command->data(), command->size());
_pool.async([src = std::move(temp_buffer),
buffer = command->handle(), offset = command->offset()] {
auto dst = reinterpret_cast<void *>(buffer + offset);
auto dst = reinterpret_cast<FallbackBuffer*>(buffer)->view(offset).ptr;
std::memcpy(dst, src->data(), src->size());
});
_pool.barrier();
}

void FallbackStream::visit(const BufferDownloadCommand *command) noexcept {
_pool.async([cmd = *command] {
auto src = reinterpret_cast<const void *>(cmd.handle() + cmd.offset());
_pool.async([cmd = *command, buffer = command->handle(), offset = command->offset()] {
auto src = reinterpret_cast<FallbackBuffer*>(buffer)->view(offset).ptr;
std::memcpy(cmd.data(), src, cmd.size());
});
_pool.barrier();
}

void FallbackStream::visit(const BufferCopyCommand *command) noexcept {
_pool.async([cmd = *command] {
auto src = reinterpret_cast<const void *>(cmd.src_handle() + cmd.src_offset());
auto dst = reinterpret_cast<void *>(cmd.dst_handle() + cmd.dst_offset());
auto src = reinterpret_cast<FallbackBuffer*>(cmd.src_handle())->view(cmd.src_offset()).ptr;
auto dst = reinterpret_cast<FallbackBuffer*>(cmd.dst_handle())->view(cmd.dst_offset()).ptr;
std::memcpy(dst, src, cmd.size());
});
_pool.barrier();
Expand Down

0 comments on commit b9a9599

Please sign in to comment.