Skip to content

Commit

Permalink
spv type gen
Browse files Browse the repository at this point in the history
  • Loading branch information
MaxwellGengYF committed Nov 27, 2024
1 parent d8436d6 commit e8afe30
Show file tree
Hide file tree
Showing 9 changed files with 447 additions and 0 deletions.
113 changes: 113 additions & 0 deletions src/backends/common/spirv_codegen/__pch.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,113 @@
#include <luisa/ast/ast2json.h>
#include <luisa/ast/atomic_ref_node.h>
#include <luisa/ast/attribute.h>
#include <luisa/ast/callable_library.h>
#include <luisa/ast/constant_data.h>
#include <luisa/ast/expression.h>
#include <luisa/ast/external_function.h>
#include <luisa/ast/function.h>
#include <luisa/ast/function_builder.h>
#include <luisa/ast/interface.h>
#include <luisa/ast/op.h>
#include <luisa/ast/statement.h>
#include <luisa/ast/type.h>
#include <luisa/ast/type_registry.h>
#include <luisa/ast/usage.h>
#include <luisa/ast/variable.h>


#include <luisa/core/basic_traits.h>
#include <luisa/core/basic_types.h>
#include <luisa/core/binary_buffer.h>
#include <luisa/core/binary_file_stream.h>
#include <luisa/core/binary_io.h>
#include <luisa/core/clock.h>
#include <luisa/core/concepts.h>
#include <luisa/core/constants.h>
#include <luisa/core/dll_export.h>
#include <luisa/core/dynamic_module.h>
#include <luisa/core/fiber.h>
#include <luisa/core/first_fit.h>
#include <luisa/core/forget.h>
#include <luisa/core/intrin.h>
#include <luisa/core/logging.h>
#include <luisa/core/macro.h>
#include <luisa/core/magic_enum.h>
#include <luisa/core/mathematics.h>
#include <luisa/core/platform.h>
#include <luisa/core/pool.h>
#include <luisa/core/shared_function.h>
#include <luisa/core/spin_mutex.h>
#include <luisa/core/stl.h>
#include <luisa/core/string_scratch.h>
#include <luisa/core/thread_safety.h>


#include <luisa/vstl/allocate_type.h>
#include <luisa/vstl/arena_hash_map.h>
#include <luisa/vstl/common.h>
#include <luisa/vstl/compare.h>
#include <luisa/vstl/config.h>
#include <luisa/vstl/functional.h>
#include <luisa/vstl/hash.h>
#include <luisa/vstl/hash_map.h>
#include <luisa/vstl/lmdb.hpp>
#include <luisa/vstl/lockfree_array_queue.h>
#include <luisa/vstl/log.h>
#include <luisa/vstl/md5.h>
#include <luisa/vstl/memory.h>
#include <luisa/vstl/meta_lib.h>
#include <luisa/vstl/pool.h>
#include <luisa/vstl/ranges.h>
#include <luisa/vstl/spin_mutex.h>
#include <luisa/vstl/stack_allocator.h>
#include <luisa/vstl/string_hash.h>
#include <luisa/vstl/string_utility.h>
#include <luisa/vstl/tree_map_base.h>
#include <luisa/vstl/unique_ptr.h>
#include <luisa/vstl/v_allocator.h>
#include <luisa/vstl/v_guid.h>
#include <luisa/vstl/vector.h>
#include <luisa/vstl/vstring.h>

#include <luisa/xir/argument.h>
#include <luisa/xir/basic_block.h>
#include <luisa/xir/builder.h>
#include <luisa/xir/constant.h>
#include <luisa/xir/function.h>
#include <luisa/xir/ilist.h>
#include <luisa/xir/instruction.h>
#include <luisa/xir/instructions/alloca.h>
#include <luisa/xir/instructions/assert.h>
#include <luisa/xir/instructions/assume.h>
#include <luisa/xir/instructions/branch.h>
#include <luisa/xir/instructions/break.h>
#include <luisa/xir/instructions/call.h>
#include <luisa/xir/instructions/cast.h>
#include <luisa/xir/instructions/continue.h>
#include <luisa/xir/instructions/gep.h>
#include <luisa/xir/instructions/if.h>
#include <luisa/xir/instructions/intrinsic.h>
#include <luisa/xir/instructions/load.h>
#include <luisa/xir/instructions/loop.h>
#include <luisa/xir/instructions/outline.h>
#include <luisa/xir/instructions/phi.h>
#include <luisa/xir/instructions/print.h>
#include <luisa/xir/instructions/ray_query.h>
#include <luisa/xir/instructions/return.h>
#include <luisa/xir/instructions/store.h>
#include <luisa/xir/instructions/switch.h>
#include <luisa/xir/instructions/unreachable.h>
#include <luisa/xir/metadata.h>
#include <luisa/xir/metadata/comment.h>
#include <luisa/xir/metadata/location.h>
#include <luisa/xir/metadata/name.h>
#include <luisa/xir/module.h>
#include <luisa/xir/pool.h>
#include <luisa/xir/translators/ast2xir.h>
#include <luisa/xir/translators/json2xir.h>
#include <luisa/xir/translators/xir2json.h>
#include <luisa/xir/translators/xir2text.h>
#include <luisa/xir/use.h>
#include <luisa/xir/user.h>
#include <luisa/xir/value.h>
206 changes: 206 additions & 0 deletions src/backends/common/spirv_codegen/codegen_lib.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,206 @@
#include "codegen_lib.h"
#include <luisa/ast/type.h>
#include <luisa/core/stl/format.h>

namespace lc::spirv {
CodegenLib::CodegenLib() {
Register::reset_register();
}

Register CodegenLib::_mark_runtime_arr(Type const *element) {
auto iter = _runtime_arr_types.try_emplace(element, vstd::lazy_eval([&] {
return Register::new_register();
}));
auto reg = iter.first.value();
if (!iter.second) return reg;
auto reg_str = reg.to_str();
_annotation_builder << luisa::format("OpDecorate {} ArrayStride {}\n", reg_str, element->size());
auto ele_type = mark_type(element);
_types_builder << luisa::format("{} = OpTypeRuntimeArray {}\n", reg_str, ele_type.to_str());
return reg;
}

auto CodegenLib::_mark_image_type(Type const *texture, bool unordered_access) -> BufferType const & {
auto tex_type = _storage_types.emplace(texture).value();
auto &struct_type = unordered_access ? tex_type.uav_struct_type : tex_type.srv_struct_type;
auto &ptr_type = unordered_access ? tex_type.uav_ptr_type : tex_type.srv_ptr_type;
if (struct_type != nullptr && ptr_type != nullptr) return tex_type;
struct_type = Register::new_register();
ptr_type = Register::new_register();
auto struct_type_str = struct_type.to_str();
auto ptr_type_str = ptr_type.to_str();
luisa::string img_format;
if (unordered_access) {
switch (texture->element()->tag()) {
case Type::Tag::INT32: {
img_format = "Rgba32i";
} break;
case Type::Tag::UINT32: {
img_format = "Rgba32iu";
} break;
case Type::Tag::FLOAT32: {
img_format = "Rgba32f";
} break;
default: {
LUISA_ERROR("Bad texture element type.");
} break;
}
} else {
img_format = "Unknown";
}
auto ele_type = mark_type(texture->element());
_types_builder
<< luisa::format("{} = OpTypeImage {} {}D 2 0 0 {} {}\n",
struct_type_str,
ele_type.to_str(),
texture->dimension(),
unordered_access ? 1 : 2,
img_format)
<< luisa::format("{} = OpTypePointer UniformConstant {}\n",
ptr_type_str,
struct_type_str);
return tex_type;
}

auto CodegenLib::_mark_buffer_type(Type const *buffer, bool unordered_access) -> BufferType const & {
auto buffer_type = _storage_types.emplace(buffer).value();
auto &struct_type = unordered_access ? buffer_type.uav_struct_type : buffer_type.srv_struct_type;
auto &ptr_type = unordered_access ? buffer_type.uav_ptr_type : buffer_type.srv_ptr_type;
if (struct_type != nullptr && ptr_type != nullptr) return buffer_type;
struct_type = Register::new_register();
ptr_type = Register::new_register();
auto struct_type_str = struct_type.to_str();
auto ptr_type_str = ptr_type.to_str();
auto runtime_arr_reg_str = _mark_runtime_arr(buffer->element()).to_str();
_types_builder << luisa::format("{} = OpTypeStruct {}\n", struct_type_str, runtime_arr_reg_str)
<< luisa::format("{} = OpTypePointer Uniform {}\n", ptr_type_str, struct_type_str);
_annotation_builder << luisa::format("OpMemberDecorate {} 0 Offset 0\n", struct_type_str);
if (!unordered_access) {
_annotation_builder << luisa::format("OpMemberDecorate {} 0 NonWritable\n", struct_type_str);
}
_annotation_builder << luisa::format("OpDecorate {} BufferBlock\n", struct_type_str);
return buffer_type;
}

Register CodegenLib::_mark_constant_value(LiteralVariantType const &value) {
auto iter = _constant_values.try_emplace(value, vstd::lazy_eval([] {
return Register::new_register();
}));
auto reg = iter.first.value();
if (!iter.second) return reg;
luisa::visit([&]<typename T>(T const &t) {
auto ele_type = mark_type(Type::of<T>());
_types_builder << luisa::format("{} = OpConstant {} {}\n", reg.to_str(), ele_type.to_str(), t);
},
value);
return reg;
}

Register CodegenLib::mark_type(Type const *type, bool unordered_access) {
switch (type->tag()) {
case Type::Tag::BUFFER: {
auto &&t = _mark_buffer_type(type, unordered_access);
return unordered_access ? t.uav_ptr_type : t.srv_ptr_type;
}
case Type::Tag::TEXTURE: {
auto &&t = _mark_image_type(type, unordered_access);
return unordered_access ? t.uav_ptr_type : t.srv_ptr_type;
}
case Type::Tag::BINDLESS_ARRAY: {
// TODO: backend bindless
LUISA_ERROR("Not implemented.");
}
case Type::Tag::ACCEL: {
if (_accel_id == nullptr) {
_accel_id = Register::new_register();
_accel_ptr_id = Register::new_register();
auto accel_id_id = _accel_id.to_str();
_types_builder << luisa::format("{} = OpTypeAccelerationStructureKHR\n", accel_id_id)
<< luisa::format("{} = OpTypePointer UniformConstant {}\n", _accel_ptr_id.to_str(), accel_id_id);
}
// TODO: backend accel instance buffer
return _accel_id;
}
}
auto iter =
_types.try_emplace(type, vstd::lazy_eval([&] {
return Register::new_register();
}));
auto reg = iter.first.value();
if (!iter.second)
return reg;
auto reg_str = reg.to_str();
if (type == nullptr) {
_types_builder << luisa::format("{} = OpTypeVoid\n", reg_str);
return reg;
}
switch (type->tag()) {
case Type::Tag::BOOL:
_types_builder << luisa::format("{} = OpTypeBool\n", reg_str);
break;
case Type::Tag::INT8:
_types_builder << luisa::format("{} = OpTypeInt 8 1\n", reg_str);
break;
case Type::Tag::UINT8:
_types_builder << luisa::format("{} = OpTypeInt 8 0\n", reg_str);
break;
case Type::Tag::INT16:
_types_builder << luisa::format("{} = OpTypeInt 16, 1\n", reg_str);
break;
case Type::Tag::UINT16:
_types_builder << luisa::format("{} = OpTypeInt 16 0\n", reg_str);
break;
case Type::Tag::INT32:
_types_builder << luisa::format("{} = OpTypeInt 32 1\n", reg_str);
break;
case Type::Tag::UINT32:
_types_builder << luisa::format("{} = OpTypeInt 32 0\n", reg_str);
break;
case Type::Tag::INT64:
_types_builder << luisa::format("{} = OpTypeInt 64 1\n", reg_str);
break;
case Type::Tag::UINT64:
_types_builder << luisa::format("{} = OpTypeInt 64 0\n", reg_str);
break;
case Type::Tag::FLOAT16:
_types_builder << luisa::format("{} = OpTypeFloat 16\n", reg_str);
break;
case Type::Tag::FLOAT32:
_types_builder << luisa::format("{} = OpTypeFloat 32\n", reg_str);
break;
case Type::Tag::FLOAT64:
_types_builder << luisa::format("{} = OpTypeFloat 64\n", reg_str);
break;
case Type::Tag::VECTOR: {
auto ele_type = mark_type(type->element());
_types_builder << luisa::format("{} = OpTypeVector {} {}\n", reg_str, ele_type.to_str(), type->dimension());
} break;
case Type::Tag::MATRIX: {
auto col_reg = mark_type(Type::vector(Type::of<float>(), type->dimension()));
_types_builder << luisa::format("{} = OpTypeMatrix {} {}\n", reg_str, col_reg.to_str(), type->dimension());
} break;
case Type::Tag::ARRAY: {
_annotation_builder << luisa::format("OpDecorate {} ArrayStride {}\n", reg_str, type->element()->size());
auto ele_type = mark_type(type->element());
_types_builder << luisa::format(
"{} = OpTypeArray {} {}\n",
reg_str,
ele_type.to_str(),
type->dimension());
} break;
case Type::Tag::STRUCTURE: {
size_t count = 0;
size_t offset = 0;
for (auto &mem : type->members()) {
offset = (offset + (mem->alignment() - 1)) & (~(mem->alignment() - 1));
_annotation_builder << luisa::format("OpMemberDecorate {} {} Offset {}\n", reg_str, count, offset);
offset += mem->size();
count++;
}
} break;
}
return reg;
}
CodegenLib::~CodegenLib() {
}
}// namespace lc::spirv
71 changes: 71 additions & 0 deletions src/backends/common/spirv_codegen/codegen_lib.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
#pragma once
#include "string_builder.h"
#include "register.h"
namespace luisa::compute {
class Type;
}// namespace luisa::compute
namespace lc::spirv {
using namespace luisa::compute;

class CodegenLib {
struct BufferType {
Register srv_struct_type;
Register srv_ptr_type;
Register uav_struct_type;
Register uav_ptr_type;
};
using LiteralVariantType = luisa::variant<
double,
int64_t,
uint64_t>;
struct LiteralVariantTypeHash {
size_t operator()(LiteralVariantType const &h) const {
return luisa::visit(
[&]<typename T>(T const &t) {
return luisa::hash64(&t, sizeof(T), h.index());
},
h);
}
};
struct LiteralVariantTypeCompare {
int operator()(LiteralVariantType const &a, LiteralVariantType const &b) const {
if (a.index() > b.index()) {
return 1;
} else if (a.index() < b.index()) {
return -1;
} else {
return luisa::visit(
[&]<typename T>(T const &t) {
auto &&b_value = *(b.get_as<std::add_pointer_t<T>>());
if (t < b_value) return -1;
if (t > b_value) return 1;
return 0;
},
a);
}
}
};
vstd::StringBuilder _annotation_builder;
vstd::StringBuilder _types_builder;
vstd::StringBuilder _body_builder;
vstd::HashMap<Type const *, Register> _types;
vstd::HashMap<Type const *, Register> _runtime_arr_types;// usually used by buffer
vstd::HashMap<Type const *, BufferType> _storage_types;
vstd::HashMap<LiteralVariantType, Register, LiteralVariantTypeHash, LiteralVariantTypeCompare> _constant_values;
Register _accel_id;
Register _accel_ptr_id;
Register _accel_inst_buffer_id;
Register _accel_inst_ptr_id;
Register _rw_accel_inst_buffer_id;
Register _rw_accel_inst_ptr_id;
Register _mark_runtime_arr(Type const *element);
BufferType const& _mark_buffer_type(Type const *buffer, bool unordered_access);
BufferType const& _mark_image_type(Type const *texture, bool unordered_access);
Register _mark_constant_value(LiteralVariantType const& value);

public:
CodegenLib();
~CodegenLib();
Register mark_type(Type const *type, bool unordered_access = false);
};
}// namespace lc::spirv
Loading

0 comments on commit e8afe30

Please sign in to comment.