Skip to content

Commit

Permalink
fallback: buffer related fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
swfly committed Nov 28, 2024
1 parent 56d52d9 commit f05d9ac
Showing 7 changed files with 87 additions and 87 deletions.
122 changes: 60 additions & 62 deletions src/backends/fallback/fallback_accel.cpp
Original file line number Diff line number Diff line change
@@ -3,6 +3,7 @@
//

#include <luisa/core/stl.h>
#include <luisa/core/logging.h>

#include "fallback_mesh.h"
#include "fallback_accel.h"
@@ -93,7 +94,7 @@ namespace luisa::compute::fallback
auto error = rtcGetDeviceError(_device);
if (error != RTC_ERROR_NONE)
{
printf("Embree Error: %d\n", error);
LUISA_INFO("RTC ERROR: {}", (uint)error);
}
});
}
@@ -116,99 +117,96 @@ namespace luisa::compute::fallback
m[3], m[7], m[11], 1.f);
}

namespace detail
{
void accel_trace_closest(const FallbackAccel* accel, float ox, float oy, float oz, float dx, float dy, float dz,
float tmin, float tmax, uint mask, SurfaceHit* hit) noexcept
{
void accel_trace_closest(const FallbackAccel* accel, float ox, float oy, float oz, float dx, float dy, float dz,
float tmin, float tmax, uint mask, SurfaceHit* hit) noexcept
{
#if LUISA_COMPUTE_EMBREE_VERSION == 3
RTCIntersectContext ctx{};
rtcInitIntersectContext(&ctx);
RTCIntersectContext ctx{};
rtcInitIntersectContext(&ctx);
#else
RTCRayQueryContext ctx{};
RTCRayQueryContext ctx{};
rtcInitRayQueryContext(&ctx);
RTCIntersectArguments args{.context = &ctx};
#endif
RTCRayHit rh{};
rh.ray.org_x = ox;
rh.ray.org_y = oy;
rh.ray.org_z = oz;
rh.ray.dir_x = dx;
rh.ray.dir_y = dy;
rh.ray.dir_z = dz;
rh.ray.tnear = tmin;
rh.ray.tfar = tmax;
RTCRayHit rh{};
rh.ray.org_x = ox;
rh.ray.org_y = oy;
rh.ray.org_z = oz;
rh.ray.dir_x = dx;
rh.ray.dir_y = dy;
rh.ray.dir_z = dz;
rh.ray.tnear = tmin;
rh.ray.tfar = tmax;

rh.ray.mask = mask;
rh.hit.geomID = RTC_INVALID_GEOMETRY_ID;
rh.hit.primID = RTC_INVALID_GEOMETRY_ID;
rh.hit.instID[0] = RTC_INVALID_GEOMETRY_ID;
rh.ray.flags = 0;
rh.ray.mask = mask;
rh.hit.geomID = RTC_INVALID_GEOMETRY_ID;
rh.hit.primID = RTC_INVALID_GEOMETRY_ID;
rh.hit.instID[0] = RTC_INVALID_GEOMETRY_ID;
rh.ray.flags = 0;
#if LUISA_COMPUTE_EMBREE_VERSION == 3
rtcIntersect1(accel->scene(), &ctx, &rh);
rtcIntersect1(accel->scene(), &ctx, &rh);
#else
rtcIntersect1(accel->scene(), &rh, &args);
rtcIntersect1(accel->scene(), &rh, &args);
#endif
hit->inst = rh.hit.instID[0];
hit->prim = rh.hit.primID;
hit->bary = make_float2(rh.hit.u, rh.hit.v);
hit->committed_ray_t = rh.ray.tfar;
}
void fill_transform(const FallbackAccel* accel, uint id, float4x4* buffer)
{
// TODO: handle embree 4
hit->inst = rh.hit.instID[0];
hit->prim = rh.hit.primID;
hit->bary = make_float2(rh.hit.u, rh.hit.v);
hit->committed_ray_t = rh.ray.tfar;
}
void fill_transform(const FallbackAccel* accel, uint id, float4x4* buffer)
{
// TODO: handle embree 4

// Retrieve the RTCInstance (you may need to store instances in your application)
auto instance = rtcGetGeometry(accel->scene(), id);
// Retrieve the RTCInstance (you may need to store instances in your application)
auto instance = rtcGetGeometry(accel->scene(), id);

// Get the transform of the instance (a 4x4 matrix)
rtcGetGeometryTransform(instance, 0.f, RTCFormat::RTC_FORMAT_FLOAT4X4_COLUMN_MAJOR, buffer);
}
// Get the transform of the instance (a 4x4 matrix)
rtcGetGeometryTransform(instance, 0.f, RTCFormat::RTC_FORMAT_FLOAT4X4_COLUMN_MAJOR, buffer);
}

bool accel_trace_any(const FallbackAccel* accel, float ox, float oy, float oz, float dx, float dy, float dz,
float tmin, float tmax, uint mask) noexcept
{
bool accel_trace_any(const FallbackAccel* accel, float ox, float oy, float oz, float dx, float dy, float dz,
float tmin, float tmax, uint mask) noexcept
{
#if LUISA_COMPUTE_EMBREE_VERSION == 3
RTCIntersectContext ctx{};
rtcInitIntersectContext(&ctx);
RTCIntersectContext ctx{};
rtcInitIntersectContext(&ctx);
#else
RTCRayQueryContext ctx{};
RTCRayQueryContext ctx{};
rtcInitRayQueryContext(&ctx);
RTCOccludedArguments args{.context = &ctx};
#endif
RTCRay ray{};
ray.org_x = ox;
ray.org_y = oy;
ray.org_z = oz;
ray.dir_x = dx;
ray.dir_y = dy;
ray.dir_z = dz;
ray.tnear = tmin;
ray.tfar = tmax;
RTCRay ray{};
ray.org_x = ox;
ray.org_y = oy;
ray.org_z = oz;
ray.dir_x = dx;
ray.dir_y = dy;
ray.dir_z = dz;
ray.tnear = tmin;
ray.tfar = tmax;

ray.mask = mask;
ray.flags = 0;
ray.mask = mask;
ray.flags = 0;
#if LUISA_COMPUTE_EMBREE_VERSION == 3
rtcOccluded1(accel->scene(), &ctx, &ray);
rtcOccluded1(accel->scene(), &ctx, &ray);
#else
rtcOccluded1(accel->scene(), &ray, &args);
rtcOccluded1(accel->scene(), &ray, &args);
#endif
return ray.tfar < 0.f;
}
} // namespace detail
return ray.tfar < 0.f;
}
} // namespace luisa::compute::fallback

void intersect_closest_wrapper(void* accel, float ox, float oy, float oz, float dx, float dy, float dz, float tmin,
float tmax, unsigned mask, void* hit)
{
luisa::compute::fallback::detail::accel_trace_closest(
luisa::compute::fallback::accel_trace_closest(
reinterpret_cast<luisa::compute::fallback::FallbackAccel *>(accel),
ox, oy, oz, dx, dy, dz, tmin, tmax, mask, reinterpret_cast<luisa::compute::SurfaceHit *>(hit));
}

void accel_transform_wrapper(void* accel, unsigned id, void* buffer)
{
luisa::compute::fallback::detail::fill_transform(
luisa::compute::fallback::fill_transform(
reinterpret_cast<luisa::compute::fallback::FallbackAccel *>(accel),
id, reinterpret_cast<luisa::float4x4 *>(buffer));
}
1 change: 1 addition & 0 deletions src/backends/fallback/fallback_accel.h
Original file line number Diff line number Diff line change
@@ -48,6 +48,7 @@ class FallbackAccel {
[[nodiscard]] static float4x4 _decompress(std::array<float, 12> m) noexcept;

public:
[[nodiscard]]auto device()const noexcept{return _device;}
[[nodiscard]] RTCScene scene()const noexcept {return _handle;}
FallbackAccel(RTCDevice device, AccelUsageHint hint) noexcept;
~FallbackAccel() noexcept;
5 changes: 2 additions & 3 deletions src/backends/fallback/fallback_bindless_array.cpp
Original file line number Diff line number Diff line change
@@ -3,9 +3,8 @@
//

#include "fallback_bindless_array.h"
#include "fallback_buffer.h"
#include "thread_pool.h"
#include "luisa/runtime/rtx/triangle.h"
#include "luisa/rust/api_types.hpp"

namespace luisa::compute::fallback
{
@@ -36,6 +35,6 @@ namespace luisa::compute::fallback
void bindless_buffer_read(void* bindless, size_t slot, size_t elem, unsigned stride, void* buffer)
{
auto a = reinterpret_cast<luisa::compute::fallback::FallbackBindlessArray *>(bindless);
auto ptr = reinterpret_cast<const std::byte*>(a->slot(slot).buffer);
auto ptr = reinterpret_cast<const luisa::compute::fallback::FallbackBuffer*>(a->slot(slot).buffer)->addr();
std::memcpy(buffer, ptr + elem*stride, stride);
}
2 changes: 1 addition & 1 deletion src/backends/fallback/fallback_buffer.h
Original file line number Diff line number Diff line change
@@ -15,7 +15,7 @@ struct alignas(16) FallbackBufferView {
class FallbackBuffer {
public:
explicit FallbackBuffer(size_t size, unsigned elementStride);
void *addr() { return data; }
std::byte *addr()const noexcept { return data; }
[[nodiscard]] FallbackBufferView view(size_t offset) noexcept;
~FallbackBuffer();
private:
1 change: 1 addition & 0 deletions src/backends/fallback/fallback_codegen.cpp
Original file line number Diff line number Diff line change
@@ -1707,6 +1707,7 @@ class FallbackCodegen {
auto hit_type = _translate_type(Type::of<luisa::compute::SurfaceHit>(), false);

auto hit_alloca = b.CreateAlloca(hit_type, nullptr, "");
hit_alloca->setAlignment(llvm::Align(8u));

// Extract ray components
auto compressed_origin = b.CreateExtractValue(llvm_ray, 0, "");
37 changes: 19 additions & 18 deletions src/backends/fallback/fallback_shader.cpp
Original file line number Diff line number Diff line change
@@ -127,15 +127,15 @@ luisa::compute::fallback::FallbackShader::FallbackShader(const luisa::compute::S
auto xir_module = xir::ast_to_xir_translate(kernel, {});
xir_module->set_name(luisa::format("kernel_{:016x}", kernel.hash()));
if (!option.name.empty()) { xir_module->set_location(option.name); }
LUISA_INFO("Kernel XIR:\n{}", xir::xir_to_text_translate(xir_module, true));
//LUISA_INFO("Kernel XIR:\n{}", xir::xir_to_text_translate(xir_module, true));

auto llvm_ctx = std::make_unique<llvm::LLVMContext>();
auto llvm_module = luisa_fallback_backend_codegen(*llvm_ctx, xir_module);
if (!llvm_module) {
LUISA_ERROR_WITH_LOCATION("Failed to generate LLVM IR.");
}
//llvm_module->print(llvm::errs(), nullptr, true, true);
llvm_module->print(llvm::outs(), nullptr, true, true);
//llvm_module->print(llvm::outs(), nullptr, true, true);
if (llvm::verifyModule(*llvm_module, &llvm::errs())) {
LUISA_ERROR_WITH_LOCATION("LLVM module verification failed.");
}
@@ -175,7 +175,7 @@ luisa::compute::fallback::FallbackShader::FallbackShader(const luisa::compute::S
if (::llvm::verifyModule(*llvm_module, &::llvm::errs())) {
LUISA_ERROR_WITH_LOCATION("Failed to verify module.");
}
llvm_module->print(llvm::outs(), nullptr, true, true);
//llvm_module->print(llvm::outs(), nullptr, true, true);

// compile to machine code
auto m = llvm::orc::ThreadSafeModule(std::move(llvm_module), std::move(llvm_ctx));
@@ -272,22 +272,23 @@ void compute::fallback::FallbackShader::dispatch(ThreadPool &pool, const compute

auto data = argument_buffer.data();

for (int i = 0; i < dispatch_counts.x; ++i) {
for (int j = 0; j < dispatch_counts.y; ++j) {
for (int k = 0; k < dispatch_counts.z; ++k) {
auto c = config;
c.block_id = make_uint3(i, j, k);
(*_kernel_entry)(data, &c);
}
}
}
// for (int i = 0; i < dispatch_counts.x; ++i) {
// for (int j = 0; j < dispatch_counts.y; ++j) {
// for (int k = 0; k < dispatch_counts.z; ++k) {
// auto c = config;
// c.block_id = make_uint3(i, j, k);
// (*_kernel_entry)(data, &c);
// }
// }
// }

// pool.parallel(dispatch_counts.x, dispatch_counts.y, dispatch_counts.z,
// [this, config, data](auto bx, auto by, auto bz) noexcept {
// auto c = config;
// c.block_id = make_uint3(bx, by, bz);
// (*_kernel_entry)(data, &c);
// });
pool.parallel(dispatch_counts.x, dispatch_counts.y, dispatch_counts.z,
[this, config, data](auto bx, auto by, auto bz) noexcept {
auto c = config;
c.block_id = make_uint3(bx, by, bz);
(*_kernel_entry)(data, &c);
});
pool.synchronize();
// pool.barrier();
}
void compute::fallback::FallbackShader::build_bound_arguments(compute::Function kernel) {
6 changes: 3 additions & 3 deletions src/backends/fallback/fallback_stream.cpp
Original file line number Diff line number Diff line change
@@ -134,18 +134,18 @@ void FallbackStream::visit(const AccelBuildCommand *command) noexcept {
}

void FallbackStream::visit(const MeshBuildCommand *command) noexcept {
auto v_b = command->vertex_buffer();
auto v_b = reinterpret_cast<FallbackBuffer*>(command->vertex_buffer())->view(0).ptr;
auto v_b_o = command->vertex_buffer_offset();
auto v_s = command->vertex_stride();
auto v_b_s = command->vertex_buffer_size();
auto v_b_c = v_b_s/v_s;
auto t_b = command->triangle_buffer();
auto t_b = reinterpret_cast<FallbackBuffer*>(command->triangle_buffer())->view(0).ptr;
auto t_b_o = command->triangle_buffer_offset();
auto t_b_s = command->triangle_buffer_size();
auto t_b_c = t_b_s/12u;
_pool.async([=,mesh = reinterpret_cast<FallbackMesh *>(command->handle())]
{
mesh->commit(v_b, v_b_o, v_s, v_b_c, t_b, t_b_o, t_b_c);
mesh->commit(reinterpret_cast<uint64_t>(v_b), v_b_o, v_s, v_b_c, reinterpret_cast<uint64_t>(t_b), t_b_o, t_b_c);
});
_pool.barrier();
}

0 comments on commit f05d9ac

Please sign in to comment.