Skip to content

Commit

Permalink
Merge branch 'dev-fallback-backend' of https://github.com/LuisaGroup/…
Browse files Browse the repository at this point in the history
…LuisaCompute into dev-fallback-backend
  • Loading branch information
Mike-Leo-Smith committed Jan 13, 2025
2 parents ef2ef05 + 66e3e52 commit 78f436e
Show file tree
Hide file tree
Showing 2 changed files with 91 additions and 115 deletions.
116 changes: 46 additions & 70 deletions src/backends/fallback/fallback_codegen.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -765,77 +765,17 @@ class FallbackCodegen {
}

[[nodiscard]] llvm::Value *_translate_binary_rotate_left(CurrentFunction &current, IRBuilder &b, const xir::Value *value, const xir::Value *shift) noexcept {
LUISA_ASSERT(value->type() == shift->type(), "Type mismatch for rotate left.");
auto llvm_value = _lookup_value(current, b, value);
auto llvm_shift = _lookup_value(current, b, shift);
auto value_type = value->type();
auto elem_type = value_type->is_vector() ? value_type->element() : value_type;
LUISA_ASSERT(value_type != nullptr, "Operand type is null.");
LUISA_ASSERT(value_type == shift->type(), "Type mismatch for rotate left.");
LUISA_ASSERT(value_type->is_scalar() || value_type->is_vector(), "Invalid operand type.");
auto bit_width = 0u;
switch (elem_type->tag()) {
case Type::Tag::INT8: [[fallthrough]];
case Type::Tag::UINT8: bit_width = 8; break;
case Type::Tag::INT16: [[fallthrough]];
case Type::Tag::UINT16: bit_width = 16; break;
case Type::Tag::INT32: [[fallthrough]];
case Type::Tag::UINT32: bit_width = 32; break;
case Type::Tag::INT64: [[fallthrough]];
case Type::Tag::UINT64: bit_width = 64; break;
default: LUISA_ERROR_WITH_LOCATION(
"Invalid operand type for rotate left operation: {}.",
elem_type->description());
}
auto llvm_elem_type = _translate_type(elem_type, true);
auto llvm_bit_width = llvm::ConstantInt::get(llvm_elem_type, bit_width);
if (value_type->is_vector()) {
llvm_bit_width = llvm::ConstantVector::getSplat(
llvm::ElementCount::getFixed(value_type->dimension()),
llvm_bit_width);
}
auto shifted_left = b.CreateShl(llvm_value, llvm_shift);
auto complement_shift = b.CreateSub(llvm_bit_width, llvm_shift);
auto shifted_right = b.CreateLShr(llvm_value, complement_shift);
return b.CreateOr(shifted_left, shifted_right);
return b.CreateIntrinsic(llvm_value->getType(), llvm::Intrinsic::fshl, {llvm_value, llvm_value, llvm_shift});
}

[[nodiscard]] llvm::Value *_translate_binary_rotate_right(CurrentFunction &current, IRBuilder &b, const xir::Value *value, const xir::Value *shift) noexcept {
// Lookup LLVM values for operands
LUISA_ASSERT(value->type() == shift->type(), "Type mismatch for rotate right.");
auto llvm_value = _lookup_value(current, b, value);
auto llvm_shift = _lookup_value(current, b, shift);
auto value_type = value->type();
auto elem_type = value_type->is_vector() ? value_type->element() : value_type;

// Type and null checks
LUISA_ASSERT(value_type != nullptr, "Operand type is null.");
LUISA_ASSERT(value_type == shift->type(), "Type mismatch for rotate right.");
LUISA_ASSERT(value_type->is_scalar() || value_type->is_vector(), "Invalid operand type.");

auto bit_width = 0u;
switch (elem_type->tag()) {
case Type::Tag::INT8: [[fallthrough]];
case Type::Tag::UINT8: bit_width = 8; break;
case Type::Tag::INT16: [[fallthrough]];
case Type::Tag::UINT16: bit_width = 16; break;
case Type::Tag::INT32: [[fallthrough]];
case Type::Tag::UINT32: bit_width = 32; break;
case Type::Tag::INT64: [[fallthrough]];
case Type::Tag::UINT64: bit_width = 64; break;
default: LUISA_ERROR_WITH_LOCATION(
"Invalid operand type for rotate right operation: {}.",
elem_type->description());
}
auto llvm_elem_type = _translate_type(elem_type, true);
auto llvm_bit_width = llvm::ConstantInt::get(llvm_elem_type, bit_width);
if (value_type->is_vector()) {
llvm_bit_width = llvm::ConstantVector::getSplat(
llvm::ElementCount::getFixed(value_type->dimension()),
llvm_bit_width);
}
auto shifted_right = b.CreateLShr(llvm_value, llvm_shift);
auto complement_shift = b.CreateSub(llvm_bit_width, llvm_shift);
auto shifted_left = b.CreateShl(llvm_value, complement_shift);
return b.CreateOr(shifted_left, shifted_right);
return b.CreateIntrinsic(llvm_value->getType(), llvm::Intrinsic::fshr, {llvm_value, llvm_value, llvm_shift});
}

[[nodiscard]] llvm::Value *_translate_binary_less(CurrentFunction &current, IRBuilder &b, const xir::Value *lhs, const xir::Value *rhs) noexcept {
Expand Down Expand Up @@ -1200,7 +1140,7 @@ class FallbackCodegen {
case Type::Tag::FLOAT32: [[fallthrough]];
case Type::Tag::FLOAT64: {
auto llvm_elem_type = _translate_type(operand_elem_type, false);
auto llvm_zero = llvm::ConstantFP::get(llvm_elem_type, 0.);
auto llvm_zero = llvm::ConstantFP::getNegativeZero(llvm_elem_type);
auto llvm_one = llvm::ConstantFP::get(llvm_elem_type, 1.);
switch (op) {
case xir::ArithmeticOp::REDUCE_SUM: return b.CreateFAddReduce(llvm_zero, llvm_operand);
Expand All @@ -1224,7 +1164,7 @@ class FallbackCodegen {
auto llvm_elem_type = llvm_mul->getType()->isVectorTy() ?
llvm::cast<llvm::VectorType>(llvm_mul->getType())->getElementType() :
llvm_mul->getType();
auto llvm_zero = llvm::ConstantFP::get(llvm_elem_type, 0.);
auto llvm_zero = llvm::ConstantFP::getNegativeZero(llvm_elem_type);
return b.CreateFAddReduce(llvm_zero, llvm_mul);
}
return b.CreateAddReduce(llvm_mul);
Expand Down Expand Up @@ -2086,7 +2026,13 @@ class FallbackCodegen {
}
case xir::ArithmeticOp::ISINF: return _translate_isinf_isnan(current, b, inst->op(), inst->operand(0u));
case xir::ArithmeticOp::ISNAN: return _translate_isinf_isnan(current, b, inst->op(), inst->operand(0u));
case xir::ArithmeticOp::ACOS: return _translate_unary_fp_math_operation(current, b, inst->operand(0u), "acos");
case xir::ArithmeticOp::ACOS: {
#if LLVM_VERSION_MAJOR >= 19
return _translate_unary_fp_math_operation(current, b, inst->operand(0u), llvm::Intrinsic::acos);
#else
return _translate_unary_fp_math_operation(current, b, inst->operand(0u), "acos");
#endif
}
case xir::ArithmeticOp::ACOSH: {
// acosh(x) = log(x + sqrt(x^2 - 1))
auto llvm_x = _lookup_value(current, b, inst->operand(0u));
Expand All @@ -2097,7 +2043,13 @@ class FallbackCodegen {
auto llvm_x_plus_sqrt = b.CreateFAdd(llvm_x, llvm_sqrt);
return b.CreateUnaryIntrinsic(llvm::Intrinsic::log, llvm_x_plus_sqrt);
}
case xir::ArithmeticOp::ASIN: return _translate_unary_fp_math_operation(current, b, inst->operand(0u), "asin");
case xir::ArithmeticOp::ASIN: {
#if LLVM_VERSION_MAJOR >= 19
return _translate_unary_fp_math_operation(current, b, inst->operand(0u), llvm::Intrinsic::asin);
#else
return _translate_unary_fp_math_operation(current, b, inst->operand(0u), "asin");
#endif
}
case xir::ArithmeticOp::ASINH: {
// asinh(x) = log(x + sqrt(x^2 + 1))
auto llvm_x = _lookup_value(current, b, inst->operand(0u));
Expand All @@ -2108,8 +2060,16 @@ class FallbackCodegen {
auto llvm_x_plus_sqrt = b.CreateFAdd(llvm_x, llvm_sqrt);
return b.CreateUnaryIntrinsic(llvm::Intrinsic::log, llvm_x_plus_sqrt);
}
case xir::ArithmeticOp::ATAN: return _translate_unary_fp_math_operation(current, b, inst->operand(0u), "atan");
case xir::ArithmeticOp::ATAN2: return _translate_binary_fp_math_operation(current, b, inst->operand(0), inst->operand(1), "atan2");
case xir::ArithmeticOp::ATAN: {
#if LLVM_VERSION_MAJOR >= 19
return _translate_unary_fp_math_operation(current, b, inst->operand(0u), llvm::Intrinsic::atan);
#else
return _translate_unary_fp_math_operation(current, b, inst->operand(0u), "atan");
#endif
}
case xir::ArithmeticOp::ATAN2: {
return _translate_binary_fp_math_operation(current, b, inst->operand(0), inst->operand(1), "atan2");
}
case xir::ArithmeticOp::ATANH: {
// atanh(x) = 0.5 * log((1 + x) / (1 - x))
auto llvm_x = _lookup_value(current, b, inst->operand(0u));
Expand All @@ -2123,6 +2083,9 @@ class FallbackCodegen {
}
case xir::ArithmeticOp::COS: return _translate_unary_fp_math_operation(current, b, inst->operand(0u), llvm::Intrinsic::cos);
case xir::ArithmeticOp::COSH: {
#if LLVM_VERSION_MAJOR >= 19
return _translate_unary_fp_math_operation(current, b, inst->operand(0u), llvm::Intrinsic::cosh);
#else
// cosh(x) = 0.5 * (exp(x) + exp(-x))
auto llvm_x = _lookup_value(current, b, inst->operand(0u));
auto llvm_exp_x = b.CreateUnaryIntrinsic(llvm::Intrinsic::exp, llvm_x);
Expand All @@ -2131,9 +2094,13 @@ class FallbackCodegen {
auto llvm_exp_sum = b.CreateFAdd(llvm_exp_x, llvm_exp_neg_x);
auto llvm_half = llvm::ConstantFP::get(llvm_x->getType(), .5f);
return b.CreateFMul(llvm_half, llvm_exp_sum);
#endif
}
case xir::ArithmeticOp::SIN: return _translate_unary_fp_math_operation(current, b, inst->operand(0u), llvm::Intrinsic::sin);
case xir::ArithmeticOp::SINH: {
#if LLVM_VERSION_MAJOR >= 19
return _translate_unary_fp_math_operation(current, b, inst->operand(0u), llvm::Intrinsic::sinh);
#else
// sinh(x) = 0.5 * (exp(x) - exp(-x))
auto llvm_x = _lookup_value(current, b, inst->operand(0u));
auto llvm_exp_x = b.CreateUnaryIntrinsic(llvm::Intrinsic::exp, llvm_x);
Expand All @@ -2142,20 +2109,29 @@ class FallbackCodegen {
auto llvm_exp_diff = b.CreateFSub(llvm_exp_x, llvm_exp_neg_x);
auto llvm_half = llvm::ConstantFP::get(llvm_x->getType(), .5f);
return b.CreateFMul(llvm_half, llvm_exp_diff);
#endif
}
case xir::ArithmeticOp::TAN: {
#if LLVM_VERSION_MAJOR >= 19
return _translate_unary_fp_math_operation(current, b, inst->operand(0u), llvm::Intrinsic::tan);
#else
auto llvm_sin = _translate_unary_fp_math_operation(current, b, inst->operand(0u), llvm::Intrinsic::sin);
auto llvm_cos = _translate_unary_fp_math_operation(current, b, inst->operand(0u), llvm::Intrinsic::cos);
return b.CreateFDiv(llvm_sin, llvm_cos);
#endif
}
case xir::ArithmeticOp::TANH: {
#if LLVM_VERSION_MAJOR >= 19
return _translate_unary_fp_math_operation(current, b, inst->operand(0u), llvm::Intrinsic::tanh);
#else
// tanh(x) = sinh(x) / cosh(x) = (exp(2x) - 1) / (exp(2x) + 1)
auto llvm_x = _lookup_value(current, b, inst->operand(0u));
auto llvm_two_x = b.CreateFMul(llvm_x, llvm::ConstantFP::get(llvm_x->getType(), 2.f));
auto llvm_exp_2x = b.CreateUnaryIntrinsic(llvm::Intrinsic::exp, llvm_two_x);
auto llvm_exp_2x_minus_one = b.CreateFSub(llvm_exp_2x, llvm::ConstantFP::get(llvm_x->getType(), 1.f));
auto llvm_exp_2x_plus_one = b.CreateFAdd(llvm_exp_2x, llvm::ConstantFP::get(llvm_x->getType(), 1.f));
return b.CreateFDiv(llvm_exp_2x_minus_one, llvm_exp_2x_plus_one);
#endif
}
case xir::ArithmeticOp::EXP: return _translate_unary_fp_math_operation(current, b, inst->operand(0u), llvm::Intrinsic::exp);
case xir::ArithmeticOp::EXP2: return _translate_unary_fp_math_operation(current, b, inst->operand(0u), llvm::Intrinsic::exp2);
Expand Down
90 changes: 45 additions & 45 deletions src/backends/fallback/fallback_shader.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,20 @@ static const bool LUISA_SHOULD_DUMP_XIR = [] {
return false;
}();

static const bool LUISA_SHOULD_DUMP_LLVM_IR = [] {
if (auto env = getenv("LUISA_DUMP_LLVM_IR")) {
return std::string_view{env} == "1";
}
return false;
}();

static const bool LUISA_SHOULD_DUMP_ASM = [] {
if (auto env = getenv("LUISA_DUMP_ASM")) {
return std::string_view{env} == "1";
}
return false;
}();

namespace luisa::compute::fallback {

[[nodiscard]] static luisa::half luisa_fallback_asin_f16(luisa::half x) noexcept { return ::half_float::asin(x); }
Expand Down Expand Up @@ -209,18 +223,9 @@ FallbackShader::FallbackShader(FallbackDevice *device, const ShaderOption &optio
luisa::string_view{parse_error.getMessage()});
}
auto codegen_feedback = luisa_fallback_backend_codegen(*llvm_ctx, llvm_module.get(), xir_module);
//llvm_module->print(llvm::errs(), nullptr, true, true);
//llvm_module->print(llvm::outs(), nullptr, true, true);
if (llvm::verifyModule(*llvm_module, &llvm::errs())) {
LUISA_ERROR_WITH_LOCATION("LLVM module verification failed.");
}
// {
// llvm_module->print(llvm::errs(), nullptr, true, true);
// // std::error_code EC;
// // llvm::raw_fd_ostream file_stream("H:/abc.ll", EC, llvm::sys::fs::OF_None);
// // llvm_module->print(file_stream, nullptr, true, true);
// // file_stream.close();
// }

// map symbols
llvm::orc::SymbolMap symbol_map{};
Expand Down Expand Up @@ -276,7 +281,6 @@ FallbackShader::FallbackShader(FallbackDevice *device, const ShaderOption &optio
LUISA_ERROR_WITH_LOCATION("Failed to define symbols.");
}

// optimize
llvm_module->setDataLayout(_target_machine->createDataLayout());
llvm_module->setTargetTriple(_target_machine->getTargetTriple().str());

Expand All @@ -291,7 +295,18 @@ FallbackShader::FallbackShader(FallbackDevice *device, const ShaderOption &optio
}
}

// optimize with the new pass manager
if (LUISA_SHOULD_DUMP_LLVM_IR) {
auto filename = luisa::format("kernel.{:016x}.ll", kernel.hash());
std::error_code ec;
llvm::raw_fd_ostream ofs{llvm::StringRef{filename}, ec};
if (ec) {
LUISA_WARNING_WITH_LOCATION("Failed to open file for dumping LLVM IR: {}.", ec.message());
} else {
llvm_module->print(ofs, nullptr, false, true);
}
}

// optimize
::llvm::LoopAnalysisManager LAM;
::llvm::FunctionAnalysisManager FAM;
::llvm::CGSCCAnalysisManager CGAM;
Expand Down Expand Up @@ -323,45 +338,30 @@ FallbackShader::FallbackShader(FallbackDevice *device, const ShaderOption &optio
if (::llvm::verifyModule(*llvm_module, &::llvm::errs())) {
LUISA_ERROR_WITH_LOCATION("Failed to verify module.");
}
// {
// std::error_code EC;
// llvm::raw_fd_ostream file_stream("bbc.ll", EC, llvm::sys::fs::OF_None);
// llvm_module->print(file_stream, nullptr, true, true);
// file_stream.close();
// }
// LUISA_INFO("Printing optimized LLVM module...");
// llvm_module->print(llvm::outs(), nullptr, false, true);

// print x64 assembly of llvm_module
if constexpr (false) {
auto asm_name = "kernel_" + std::to_string(kernel.hash()) + ".s";
{
std::error_code EC;
llvm::raw_fd_ostream dest(asm_name, EC, llvm::sys::fs::OF_None);
llvm::legacy::PassManager pass;

if (EC) {
LUISA_ERROR_WITH_LOCATION("Could not open file: {}", EC.message());
}
if (LUISA_SHOULD_DUMP_LLVM_IR) {
auto filename = luisa::format("kernel.{:016x}.opt.ll", kernel.hash());
std::error_code ec;
llvm::raw_fd_ostream ofs{llvm::StringRef{filename}, ec};
if (ec) {
LUISA_WARNING_WITH_LOCATION("Failed to open file for dumping optimized LLVM IR: {}.", ec.message());
} else {
llvm_module->print(ofs, nullptr, false, true);
}
}

if (_target_machine->addPassesToEmitFile(pass, dest, nullptr, llvm::CodeGenFileType::AssemblyFile)) {
if (LUISA_SHOULD_DUMP_ASM) {
auto asm_name = luisa::format("kernel.{:016x}.s", kernel.hash());
std::error_code ec;
llvm::raw_fd_ostream ofs{llvm::StringRef{asm_name}, ec};
if (ec) {
LUISA_WARNING_WITH_LOCATION("Failed to open file for dumping assembly: {}.", ec.message());
} else {
llvm::legacy::PassManager pass;
if (_target_machine->addPassesToEmitFile(pass, ofs, nullptr, llvm::CodeGenFileType::AssemblyFile)) {
LUISA_ERROR_WITH_LOCATION("TheTargetMachine can't emit a file of this type");
}
pass.run(*llvm_module);
dest.flush();
}

std::ifstream asm_file(asm_name);
if (asm_file.is_open()) {
std::stringstream buffer;
buffer << asm_file.rdbuf();
LUISA_INFO("Kernel Assembly:\n{}", buffer.str());
asm_file.close();
if (std::remove(asm_name.c_str()) != 0) {
LUISA_WARNING_WITH_LOCATION("Failed to delete assembly file: {}", asm_name);
}
} else {
LUISA_ERROR_WITH_LOCATION("Failed to open assembly file: {}", asm_name);
}
}

Expand Down

0 comments on commit 78f436e

Please sign in to comment.