Skip to content

Commit

Permalink
minor
Browse files Browse the repository at this point in the history
  • Loading branch information
Mike-Leo-Smith committed Dec 11, 2024
1 parent d707c97 commit 960a922
Show file tree
Hide file tree
Showing 4 changed files with 43 additions and 31 deletions.
1 change: 0 additions & 1 deletion include/luisa/runtime/rhi/resource.h
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,6 @@ struct SwapchainCreationInfo : public ResourceCreationInfo {
};

struct ShaderCreationInfo : public ResourceCreationInfo {
// luisa::string name;
uint3 block_size;

[[nodiscard]] static auto make_invalid() noexcept {
Expand Down
13 changes: 9 additions & 4 deletions src/backends/fallback/fallback_device.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@

#include <luisa/core/stl.h>
#include <luisa/core/logging.h>
#include <luisa/core/clock.h>
#include "fallback_stream.h"
#include "fallback_device.h"
#include "fallback_texture.h"
Expand Down Expand Up @@ -182,10 +183,14 @@ SwapchainCreationInfo FallbackDevice::create_swapchain(const SwapchainOption &op
}

ShaderCreationInfo FallbackDevice::create_shader(const ShaderOption &option, Function kernel) noexcept {
return ShaderCreationInfo{
ResourceCreationInfo{
.handle = reinterpret_cast<uint64_t>(luisa::new_with_allocator<FallbackShader>(option, kernel))}};
return ShaderCreationInfo();
Clock clk;
auto shader = luisa::new_with_allocator<FallbackShader>(option, kernel);
LUISA_VERBOSE("Shader compilation took {} ms.", clk.toc());
ShaderCreationInfo info{};
info.handle = reinterpret_cast<uint64_t>(shader);
info.native_handle = shader->native_handle();
info.block_size = kernel.block_size();
return info;
}

ShaderCreationInfo FallbackDevice::create_shader(const ShaderOption &option, const ir::KernelModule *kernel) noexcept {
Expand Down
58 changes: 33 additions & 25 deletions src/backends/fallback/fallback_shader.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -31,21 +31,25 @@

namespace luisa::compute::fallback {

[[nodiscard]] static luisa::half luisa_asin_f16(luisa::half x) noexcept { return ::half_float::asin(x); }
[[nodiscard]] static float luisa_asin_f32(float x) noexcept { return std::asin(x); }
[[nodiscard]] static double luisa_asin_f64(double x) noexcept { return std::asin(x); }
[[nodiscard]] static luisa::half luisa_fallback_asin_f16(luisa::half x) noexcept { return ::half_float::asin(x); }
[[nodiscard]] static float luisa_fallback_asin_f32(float x) noexcept { return std::asin(x); }
[[nodiscard]] static double luisa_fallback_asin_f64(double x) noexcept { return std::asin(x); }

[[nodiscard]] static luisa::half luisa_acos_f16(luisa::half x) noexcept { return ::half_float::acos(x); }
[[nodiscard]] static float luisa_acos_f32(float x) noexcept { return std::acos(x); }
[[nodiscard]] static double luisa_acos_f64(double x) noexcept { return std::acos(x); }
[[nodiscard]] static luisa::half luisa_fallback_acos_f16(luisa::half x) noexcept { return ::half_float::acos(x); }
[[nodiscard]] static float luisa_fallback_acos_f32(float x) noexcept { return std::acos(x); }
[[nodiscard]] static double luisa_fallback_acos_f64(double x) noexcept { return std::acos(x); }

[[nodiscard]] static luisa::half luisa_atan_f16(luisa::half x) noexcept { return ::half_float::atan(x); }
[[nodiscard]] static float luisa_atan_f32(float x) noexcept { return std::atan(x); }
[[nodiscard]] static double luisa_atan_f64(double x) noexcept { return std::atan(x); }
[[nodiscard]] static luisa::half luisa_fallback_atan_f16(luisa::half x) noexcept { return ::half_float::atan(x); }
[[nodiscard]] static float luisa_fallback_atan_f32(float x) noexcept { return std::atan(x); }
[[nodiscard]] static double luisa_fallback_atan_f64(double x) noexcept { return std::atan(x); }

[[nodiscard]] static luisa::half luisa_atan2_f16(luisa::half a, luisa::half b) noexcept { return ::half_float::atan2(a, b); }
[[nodiscard]] static float luisa_atan2_f32(float a, float b) noexcept { return std::atan2(a, b); }
[[nodiscard]] static double luisa_atan2_f64(double a, double b) noexcept { return std::atan2(a, b); }
[[nodiscard]] static luisa::half luisa_fallback_atan2_f16(luisa::half a, luisa::half b) noexcept { return ::half_float::atan2(a, b); }
[[nodiscard]] static float luisa_fallback_atan2_f32(float a, float b) noexcept { return std::atan2(a, b); }
[[nodiscard]] static double luisa_fallback_atan2_f64(double a, double b) noexcept { return std::atan2(a, b); }

static void luisa_fallback_assert(bool condition, const char *message) noexcept {
if (!condition) { LUISA_ERROR_WITH_LOCATION("Assertion failed: {}.", message); }
}

struct FallbackShaderLaunchConfig {
uint3 block_id;
Expand All @@ -68,6 +72,7 @@ FallbackShader::FallbackShader(const ShaderOption &option, Function kernel) noex
options.ApproxFuncFPMath = true;
options.EnableIPRA = true;
options.StackSymbolOrdering = true;
options.TrapUnreachable = false;
options.EnableMachineFunctionSplitter = true;
options.EnableMachineOutliner = true;
options.NoTrapAfterNoreturn = true;
Expand Down Expand Up @@ -118,18 +123,21 @@ FallbackShader::FallbackShader(const ShaderOption &option, Function kernel) noex
#include "fallback_device_api_map_symbols.inl.h"

// asin, acos, atan, atan2
map_symbol("luisa.asin.f16", &luisa_asin_f16);
map_symbol("luisa.asin.f32", &luisa_asin_f32);
map_symbol("luisa.asin.f64", &luisa_asin_f64);
map_symbol("luisa.acos.f16", &luisa_acos_f16);
map_symbol("luisa.acos.f32", &luisa_acos_f32);
map_symbol("luisa.acos.f64", &luisa_acos_f64);
map_symbol("luisa.atan.f16", &luisa_atan_f16);
map_symbol("luisa.atan.f32", &luisa_atan_f32);
map_symbol("luisa.atan.f64", &luisa_atan_f64);
map_symbol("luisa.atan2.f16", &luisa_atan2_f16);
map_symbol("luisa.atan2.f32", &luisa_atan2_f32);
map_symbol("luisa.atan2.f64", &luisa_atan2_f64);
map_symbol("luisa.asin.f16", &luisa_fallback_asin_f16);
map_symbol("luisa.asin.f32", &luisa_fallback_asin_f32);
map_symbol("luisa.asin.f64", &luisa_fallback_asin_f64);
map_symbol("luisa.acos.f16", &luisa_fallback_acos_f16);
map_symbol("luisa.acos.f32", &luisa_fallback_acos_f32);
map_symbol("luisa.acos.f64", &luisa_fallback_acos_f64);
map_symbol("luisa.atan.f16", &luisa_fallback_atan_f16);
map_symbol("luisa.atan.f32", &luisa_fallback_atan_f32);
map_symbol("luisa.atan.f64", &luisa_fallback_atan_f64);
map_symbol("luisa.atan2.f16", &luisa_fallback_atan2_f16);
map_symbol("luisa.atan2.f32", &luisa_fallback_atan2_f32);
map_symbol("luisa.atan2.f64", &luisa_fallback_atan2_f64);

// assert
map_symbol("luisa.assert", &luisa_fallback_assert);

if (auto error = _jit->getMainJITDylib().define(
::llvm::orc::absoluteSymbols(std::move(symbol_map)))) {
Expand Down Expand Up @@ -161,7 +169,7 @@ FallbackShader::FallbackShader(const ShaderOption &option, Function kernel) noex
auto llvm_ctx = std::make_unique<llvm::LLVMContext>();
auto builtin_module = fallback_backend_device_builtin_module();
llvm::SMDiagnostic parse_error;
auto llvm_module = llvm::parseIR(llvm::MemoryBufferRef{builtin_module, "module"}, parse_error, *llvm_ctx);
auto llvm_module = llvm::parseIR(llvm::MemoryBufferRef{builtin_module, ""}, parse_error, *llvm_ctx);
if (!llvm_module) {
LUISA_ERROR_WITH_LOCATION("Failed to generate LLVM IR: {}.",
luisa::string_view{parse_error.getMessage()});
Expand Down
2 changes: 1 addition & 1 deletion src/backends/fallback/fallback_shader.h
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ class FallbackShader {
[[nodiscard]] auto argument_buffer_size() const noexcept { return _argument_buffer_size; }
[[nodiscard]] auto shared_memory_size() const noexcept { return _shared_memory_size; }
[[nodiscard]] size_t argument_offset(uint uid) const noexcept;
//[[nodiscard]] auto callbacks() const noexcept { return _callbacks.data(); }
[[nodiscard]] auto native_handle() const noexcept { return _kernel_entry; }
};

}// namespace luisa::compute::fallback

0 comments on commit 960a922

Please sign in to comment.