From cbf18dbef6aa05369c6c2563bcb6c9a67926962b Mon Sep 17 00:00:00 2001 From: Keith Winstein Date: Fri, 25 Aug 2023 00:33:16 -0700 Subject: [PATCH] wasm2c: implement the tail-call proposal --- include/wabt/ir.h | 7 + src/binary-reader-ir.cc | 10 + src/c-writer.cc | 441 +++++++++++++++++- src/prebuilt/wasm2c_source_declarations.cc | 35 +- src/template/wasm2c.declarations.c | 25 +- src/tools/wasm2c.cc | 2 +- test/run-spec-wasm2c.py | 3 + test/strip/remove_section.txt | 1 + test/wasm2c/add.txt | 33 +- test/wasm2c/check-imports.txt | 35 +- test/wasm2c/export-names.txt | 25 +- test/wasm2c/hello.txt | 64 ++- test/wasm2c/minimal.txt | 25 +- test/wasm2c/spec/tail-call/return_call.txt | 6 + .../spec/tail-call/return_call_indirect.txt | 6 + wasm2c/examples/callback/main.c | 3 +- wasm2c/examples/fac/fac.c | 25 +- wasm2c/wasm-rt-impl.c | 2 +- wasm2c/wasm-rt.h | 16 +- 19 files changed, 704 insertions(+), 60 deletions(-) create mode 100644 test/wasm2c/spec/tail-call/return_call.txt create mode 100644 test/wasm2c/spec/tail-call/return_call_indirect.txt diff --git a/include/wabt/ir.h b/include/wabt/ir.h index 28818c91cf..ee4e1dc969 100644 --- a/include/wabt/ir.h +++ b/include/wabt/ir.h @@ -893,6 +893,13 @@ struct Func { BindingHash bindings; ExprList exprs; Location loc; + + // For a subset of features, the BinaryReaderIR tracks whether they are + // actually used by the function. wasm2c (CWriter) uses this information to + // limit its output in some cases. + struct { + bool tailcall = false; + } features_used; }; struct Global { diff --git a/src/binary-reader-ir.cc b/src/binary-reader-ir.cc index 2eb2342b35..e9b8ee562f 100644 --- a/src/binary-reader-ir.cc +++ b/src/binary-reader-ir.cc @@ -892,12 +892,22 @@ Result BinaryReaderIR::OnCallRefExpr() { } Result BinaryReaderIR::OnReturnCallExpr(Index func_index) { + if (current_func_) { + // syntactically, a return_call expr can occur in an init expression + // (outside a function) + current_func_->features_used.tailcall = true; + } return AppendExpr( std::make_unique(Var(func_index, GetLocation()))); } Result BinaryReaderIR::OnReturnCallIndirectExpr(Index sig_index, Index table_index) { + if (current_func_) { + // syntactically, a return_call_indirect expr can occur in an init + // expression (outside a function) + current_func_->features_used.tailcall = true; + } auto expr = std::make_unique(); SetFuncDeclaration(&expr->decl, Var(sig_index, GetLocation())); expr->table = Var(table_index, GetLocation()); diff --git a/src/c-writer.cc b/src/c-writer.cc index 7076fdb54b..939496c2e9 100644 --- a/src/c-writer.cc +++ b/src/c-writer.cc @@ -103,6 +103,11 @@ struct ExternalRef : GlobalName { using GlobalName::GlobalName; }; +struct TailCallRef : GlobalName { + explicit TailCallRef(const std::string& name) + : GlobalName(ModuleFieldType::Func, name) {} +}; + struct ExternalInstancePtr : GlobalName { using GlobalName::GlobalName; }; @@ -265,6 +270,9 @@ class CWriter { static std::string ExportName(std::string_view module_name, std::string_view export_name); std::string ExportName(std::string_view export_name) const; + static std::string TailCallExportName(std::string_view module_name, + std::string_view export_name); + std::string TailCallExportName(std::string_view export_name) const; std::string ModuleInstanceTypeName() const; static std::string ModuleInstanceTypeName(std::string_view module_name); void ClaimName(SymbolSet& set, @@ -296,6 +304,7 @@ class CWriter { std::string GetGlobalName(ModuleFieldType, const std::string&) const; std::string GetLocalName(const std::string&, bool is_label) const; + std::string GetTailCallRef(const std::string&) const; void Indent(int size = INDENT_SIZE); void Dedent(int size = INDENT_SIZE); @@ -332,6 +341,7 @@ class CWriter { void Write(const GlobalName&); void Write(const TagSymbol&); void Write(const ExternalRef&); + void Write(const TailCallRef&); void Write(const ExternalInstancePtr&); void Write(const ExternalInstanceRef&); void Write(Type); @@ -350,6 +360,7 @@ class CWriter { void WriteMultiCTop(); void WriteMultiCTopEmpty(); void WriteMultivalueType(const TypeVector&); + void WriteMultivalueParamTypes(); void WriteMultivalueResultTypes(); void WriteTagTypes(); void WriteFuncTypeDecls(); @@ -360,8 +371,10 @@ class CWriter { void ComputeUniqueImports(); void BeginInstance(); void WriteImports(); + void WriteTailCallWeakImports(); void WriteFuncDeclarations(); void WriteFuncDeclaration(const FuncDeclaration&, const std::string&); + void WriteTailCallFuncDeclaration(const std::string&); void WriteImportFuncDeclaration(const FuncDeclaration&, const std::string& module_name, const std::string&); @@ -390,6 +403,7 @@ class CWriter { void WriteElemInitializers(); void WriteElemTableInit(bool, const ElemSegment*, const Table*); void WriteExports(CWriterPhase); + void WriteTailCallExports(CWriterPhase); void WriteInitDecl(); void WriteFreeDecl(); void WriteGetFuncTypeDecl(); @@ -400,6 +414,7 @@ class CWriter { void WriteImportProperties(CWriterPhase); void WriteFuncs(); void Write(const Func&); + void WriteTailCallee(const Func&); void WriteParamsAndLocals(); void WriteParams(const std::vector& index_to_name); void WriteParamSymbols(const std::vector& index_to_name); @@ -407,7 +422,10 @@ class CWriter { void WriteLocals(const std::vector& index_to_name); void WriteStackVarDeclarations(); void Write(const ExprList&); + void WriteTailCallAsserts(const FuncSignature&); + void WriteTailCallStack(); void WriteUnwindTryCatchStack(const Label*); + void FinishReturnCall(); void Spill(const TypeVector&, bool); void Unspill(const TypeVector&, bool); @@ -501,6 +519,8 @@ class CWriter { name_to_output_file_index_; bool simd_used_in_header_; + + bool in_tail_callee_; }; // TODO: if WABT begins supporting debug names for labels, @@ -516,6 +536,9 @@ static constexpr char kLabelSuffix = kParamSuffix + 1; static constexpr char kGlobalSymbolPrefix[] = "w2c_"; static constexpr char kLocalSymbolPrefix[] = "var_"; static constexpr char kAdminSymbolPrefix[] = "wasm2c_"; +static constexpr char kTailCallSymbolPrefix[] = "wasm2c_tailcall_"; +static constexpr char kTailCallFallbackPrefix[] = "wasm2c_fallback_"; +static constexpr unsigned int kTailCallStackSize = 1024; size_t CWriter::MarkTypeStack() const { return type_stack_.size(); @@ -655,6 +678,18 @@ std::string CWriter::ExportName(std::string_view module_name, MangleName(export_name); } +/* The C symbol for a tail-callee export from this module. */ +std::string CWriter::TailCallExportName(std::string_view export_name) const { + return kTailCallSymbolPrefix + ExportName(export_name); +} + +/* The C symbol for a tail-callee export from an arbitrary module. */ +// static +std::string CWriter::TailCallExportName(std::string_view module_name, + std::string_view export_name) { + return kTailCallSymbolPrefix + ExportName(module_name, export_name); +} + /* The type name of an instance of this module. */ std::string CWriter::ModuleInstanceTypeName() const { return kGlobalSymbolPrefix + module_prefix_; @@ -925,6 +960,10 @@ std::string CWriter::GetLocalName(const std::string& name, return local_sym_map_.at(mangled); } +std::string CWriter::GetTailCallRef(const std::string& name) const { + return kTailCallSymbolPrefix + GetGlobalName(ModuleFieldType::Func, name); +} + std::string CWriter::DefineParamName(std::string_view name) { return DefineLocalScopeName(name, false); } @@ -1069,6 +1108,10 @@ void CWriter::Write(const ExternalRef& name) { } } +void CWriter::Write(const TailCallRef& name) { + Write(GetTailCallRef(name.name)); +} + void CWriter::Write(const ExternalInstanceRef& name) { if (IsImport(name.name)) { Write("(*instance->", GlobalName(name), ")"); @@ -1368,7 +1411,13 @@ void CWriter::WriteInitExprTerminal(const Expr* expr) { Write("(wasm_rt_funcref_t){", FuncTypeExpr(func_type), ", ", "(wasm_rt_function_ptr_t)", - ExternalRef(ModuleFieldType::Func, func->name), ", "); + ExternalRef(ModuleFieldType::Func, func->name), ", {"); + if (IsImport(func->name) || func->features_used.tailcall) { + Write(TailCallRef(func->name)); + } else { + Write("NULL"); + } + Write("}, "); if (IsImport(func->name)) { Write("instance->", GlobalName(ModuleFieldType::Import, @@ -1452,6 +1501,15 @@ void CWriter::WriteMultivalueType(const TypeVector& types) { Write(CloseBrace(), ";", Newline(), "#endif /* ", name, " */", Newline()); } +void CWriter::WriteMultivalueParamTypes() { + for (TypeEntry* type : module_->types) { + FuncType* func_type = cast(type); + if (func_type->GetNumParams() > 1) { + WriteMultivalueType(func_type->sig.param_types); + } + } +} + void CWriter::WriteMultivalueResultTypes() { for (TypeEntry* type : module_->types) { FuncType* func_type = cast(type); @@ -1732,6 +1790,8 @@ void CWriter::WriteImports() { ExportName(import->module_name, import->field_name)); Write(";"); Write(Newline()); + WriteTailCallFuncDeclaration(GetTailCallRef(func.name)); + Write(";", Newline()); } else if (import->kind() == ExternalKind::Tag) { Write(Newline(), "/* import: '", SanitizeForComment(import->module_name), "' '", SanitizeForComment(import->field_name), "' */", Newline()); @@ -1742,6 +1802,67 @@ void CWriter::WriteImports() { } } +void CWriter::WriteTailCallWeakImports() { + for (const Import* import : unique_imports_) { + if (import->kind() != ExternalKind::Func) { + continue; + } + const Func& func = cast(import)->func; + Write(Newline(), "/* handler for missing tail-call on import: '", + SanitizeForComment(import->module_name), "' '", + SanitizeForComment(import->field_name), "' */", Newline()); + Write("#ifdef _MSC_VER", Newline(), + "#pragma comment(linker, \"/alternatename:", + TailCallExportName(import->module_name, import->field_name), "=", + kTailCallFallbackPrefix, module_prefix_, "_", + ExportName(import->module_name, import->field_name), "\")", + Newline()); + WriteTailCallFuncDeclaration( + kTailCallFallbackPrefix + module_prefix_ + '_' + + ExportName(import->module_name, import->field_name)); + Write(Newline(), "#else", Newline()); + Write("__attribute__((weak)) "); + WriteTailCallFuncDeclaration( + TailCallExportName(import->module_name, import->field_name)); + Write(Newline(), "#endif", Newline(), OpenBrace()); + + Index num_params = func.GetNumParams(); + Index num_results = func.GetNumResults(); + + if (num_params == 1) { + Write(func.GetParamType(0), " ", "param", " = *(", func.GetParamType(0), + "*)tail_call_stack;", Newline()); + } else if (num_params > 1) { + Write(func.decl.sig.param_types, " *params = tail_call_stack;", + Newline()); + } + + if (num_results >= 1) { + Write(func.decl.sig.result_types, " result = "); + } + + Write(ExportName(import->module_name, import->field_name), + "(*instance_ptr"); + + if (num_params == 1) { + Write(", param"); + } else if (num_params > 1) { + for (Index i = 0; i < num_params; ++i) { + Writef(", params->%c%d", MangleType(func.GetParamType(i)), i); + } + } + + Write(");", Newline()); + + if (num_results >= 1) { + Write("wasm_rt_memcpy(tail_call_stack, &result, sizeof(result));", + Newline()); + } + + Write(CloseBrace(), Newline()); + } +} + void CWriter::WriteFuncDeclarations() { if (module_->funcs.size() == module_->num_func_imports) return; @@ -1756,6 +1877,12 @@ void CWriter::WriteFuncDeclarations() { WriteFuncDeclaration( func->decl, DefineGlobalScopeName(ModuleFieldType::Func, func->name)); Write(";", Newline()); + + if (func->features_used.tailcall) { + Write(InternalSymbolScope()); + WriteTailCallFuncDeclaration(GetTailCallRef(func->name)); + Write(";", Newline()); + } } ++func_index; } @@ -1769,6 +1896,12 @@ void CWriter::WriteFuncDeclaration(const FuncDeclaration& decl, Write(")"); } +void CWriter::WriteTailCallFuncDeclaration(const std::string& mangled_name) { + Write("void ", mangled_name, + "(void **instance_ptr, void *tail_call_stack, wasm_rt_tailcallee_t " + "*next)"); +} + void CWriter::WriteImportFuncDeclaration(const FuncDeclaration& decl, const std::string& module_name, const std::string& name) { @@ -2131,7 +2264,14 @@ void CWriter::WriteElemInitializers() { const Func* func = module_->GetFunc(cast(&expr)->var); const FuncType* func_type = module_->GetFuncType(func->decl.type_var); Write("{", FuncTypeExpr(func_type), ", (wasm_rt_function_ptr_t)", - ExternalRef(ModuleFieldType::Func, func->name), ", "); + ExternalRef(ModuleFieldType::Func, func->name), ", {"); + if (IsImport(func->name) || func->features_used.tailcall) { + Write(TailCallRef(func->name)); + } else { + Write("NULL"); + } + Write("}, "); + if (IsImport(func->name)) { Write("offsetof(", ModuleInstanceTypeName(), ", ", GlobalName(ModuleFieldType::Import, @@ -2143,7 +2283,7 @@ void CWriter::WriteElemInitializers() { Write("},", Newline()); } break; case ExprType::RefNull: - Write("{NULL, NULL, 0},", Newline()); + Write("{NULL, NULL, {NULL}, 0},", Newline()); break; default: WABT_UNREACHABLE; @@ -2369,6 +2509,35 @@ void CWriter::WriteExports(CWriterPhase kind) { } } +void CWriter::WriteTailCallExports(CWriterPhase kind) { + for (const Export* export_ : module_->exports) { + if (export_->kind != ExternalKind::Func) { + continue; + } + + const Func* func = module_->GetFunc(export_->var); + + if (!func->features_used.tailcall) { + continue; + } + + const std::string mangled_name = TailCallExportName(export_->name); + + Write(Newline(), "/* export for tail-call of '", + SanitizeForComment(export_->name), "' */", Newline()); + if (kind == CWriterPhase::Declarations) { + WriteTailCallFuncDeclaration(mangled_name); + Write(";", Newline()); + } else { + WriteTailCallFuncDeclaration(mangled_name); + Write(" ", OpenBrace()); + const Func* func = module_->GetFunc(export_->var); + Write(TailCallRef(func->name), "(instance_ptr, tail_call_stack, next);", + Newline(), CloseBrace(), Newline()); + } + } +} + void CWriter::WriteInit() { Write(Newline(), "void ", kAdminSymbolPrefix, module_prefix_, "_instantiate(", ModuleInstanceTypeName(), "* instance"); @@ -2582,6 +2751,9 @@ void CWriter::WriteFuncs() { if (!is_import) { stream_ = c_streams_.at(c_stream_assignment.at(func_index)); Write(*func); + if (func->features_used.tailcall) { + WriteTailCallee(*func); + } } ++func_index; } @@ -2622,6 +2794,26 @@ void CWriter::Unspill(const TypeVector& types, bool ptr) { Unspill(types, ptr, [&](auto i) { return StackVar(types.size() - i - 1); }); } +void CWriter::WriteTailCallAsserts(const FuncSignature& sig) { + if (sig.param_types.size()) { + Write("wasm_static_assert(sizeof(", sig.param_types, + ") <= ", kTailCallStackSize, ");", Newline()); + } + if (sig.result_types.size() && sig.result_types != sig.param_types) { + Write("wasm_static_assert(sizeof(", sig.result_types, + ") <= ", kTailCallStackSize, ");", Newline()); + } +} + +void CWriter::WriteTailCallStack() { + Write("void *instance_ptr_storage;", Newline()); + Write("void **instance_ptr = &instance_ptr_storage;", Newline()); + Write("char tail_call_stack[", std::to_string(kTailCallStackSize), "];", + Newline()); + Write("wasm_rt_tailcallee_t next_storage;", Newline(), Newline()); + Write("wasm_rt_tailcallee_t *next = &next_storage;", Newline()); +} + void CWriter::WriteUnwindTryCatchStack(const Label* label) { assert(try_catch_stack_.size() >= label->try_catch_stack_size); @@ -2633,8 +2825,31 @@ void CWriter::WriteUnwindTryCatchStack(const Label* label) { } } +void CWriter::FinishReturnCall() { + if (in_tail_callee_) { + Write("return;", Newline()); + return; + } + + Write("while (next->fn) { ", + "next->fn( instance_ptr, tail_call_stack, next ); }", Newline()); + PushTypes(func_->decl.sig.result_types); + const Index num_results = func_->decl.sig.result_types.size(); + if (num_results == 1) { + Write(StackVar(0), " = *(", StackType(0), "*)tail_call_stack;", Newline()); + } else if (num_results > 1) { + Write(OpenBrace(), func_->decl.sig.result_types, " *tmp = tail_call_stack;", + Newline()); + Unspill(func_->decl.sig.result_types, true); + Write(CloseBrace(), Newline()); + } + + Write("goto ", LabelName(kImplicitFuncLabel), ";", Newline()); +} + void CWriter::Write(const Func& func) { func_ = &func; + in_tail_callee_ = false; local_syms_.clear(); local_sym_map_.clear(); stack_var_sym_map_.clear(); @@ -2700,6 +2915,110 @@ void CWriter::Write(const Func& func) { func_ = nullptr; } +void CWriter::WriteTailCallee(const Func& func) { + func_ = &func; + in_tail_callee_ = true; + local_syms_.clear(); + local_sym_map_.clear(); + stack_var_sym_map_.clear(); + func_sections_.clear(); + func_includes_.clear(); + + Stream* prev_stream = stream_; + + Write(Newline()); + + PushFuncSection(); + WriteTailCallFuncDeclaration(GetTailCallRef(func.name)); + Write(" ", OpenBrace()); + WriteTailCallAsserts(func.decl.sig); + Write(ModuleInstanceTypeName(), "* instance = *instance_ptr;", Newline()); + + std::vector index_to_name; + MakeTypeBindingReverseMapping(func_->GetNumParamsAndLocals(), func_->bindings, + &index_to_name); + if (func_->GetNumParams() == 1) { + Write(func_->GetParamType(0), " ", DefineParamName(index_to_name[0]), + " = *(", func_->GetParamType(0), "*)tail_call_stack;", Newline()); + } else if (func_->GetNumParams() > 1) { + for (Type type : {Type::I32, Type::I64, Type::F32, Type::F64, Type::V128, + Type::FuncRef, Type::ExternRef}) { + Index param_index = 0; + size_t count = 0; + for (Type param_type : func_->decl.sig.param_types) { + if (param_type == type) { + if (count == 0) { + Write(type, " "); + Indent(4); + } else { + Write(", "); + if ((count % 8) == 0) + Write(Newline()); + } + + Write(DefineParamName(index_to_name[param_index])); + ++count; + } + ++param_index; + } + if (count != 0) { + Dedent(4); + Write(";", Newline()); + } + } + Write(OpenBrace(), func_->decl.sig.param_types, " *tmp = tail_call_stack;", + Newline()); + Unspill(func_->decl.sig.param_types, true, + [&](auto i) { return ParamName(index_to_name[i]); }); + Write(CloseBrace(), Newline()); + } + + WriteLocals(index_to_name); + + PushFuncSection(); + + std::string label = DefineLabelName(kImplicitFuncLabel); + ResetTypeStack(0); + std::string empty; // Must not be temporary, since address is taken by Label. + PushLabel(LabelType::Func, empty, func.decl.sig); + Write(func.exprs, LabelDecl(label)); + PopLabel(); + ResetTypeStack(0); + PushTypes(func.decl.sig.result_types); + + // Return the top of the stack implicitly. + Index num_results = func.GetNumResults(); + if (num_results == 1) { + Write("wasm_rt_memcpy(tail_call_stack, &", StackVar(0), ", sizeof(", + StackVar(0), "));", Newline()); + } else if (num_results > 1) { + Write(OpenBrace(), func.decl.sig.result_types, " *tmp = tail_call_stack;", + Newline()); + Spill(func.decl.sig.result_types, true); + Write(CloseBrace(), Newline()); + } + Write("next->fn = NULL;", Newline()); + + stream_ = prev_stream; + + for (size_t i = 0; i < func_sections_.size(); ++i) { + auto& [condition, stream] = func_sections_.at(i); + std::unique_ptr buf = stream.ReleaseOutputBuffer(); + if (condition.empty() || func_includes_.count(condition)) { + stream_->WriteData(buf->data.data(), buf->data.size()); + } + + if (i == 0) { + WriteStackVarDeclarations(); // these come immediately after section #0 + // (return type/name/params/locals) + } + } + + Write(CloseBrace(), Newline()); + + func_ = nullptr; +} + void CWriter::WriteParamsAndLocals() { std::vector index_to_name; MakeTypeBindingReverseMapping(func_->GetNumParamsAndLocals(), func_->bindings, @@ -3363,7 +3682,13 @@ void CWriter::Write(const ExprList& exprs) { Write(StackVar(0), " = (wasm_rt_funcref_t){", FuncTypeExpr(func_type), ", (wasm_rt_function_ptr_t)", - ExternalRef(ModuleFieldType::Func, func->name), ", "); + ExternalRef(ModuleFieldType::Func, func->name), ", {"); + if (IsImport(func->name) || func->features_used.tailcall) { + Write(TailCallRef(func->name)); + } else { + Write("NULL"); + } + Write("}, "); if (IsImport(func->name)) { Write("instance->", GlobalName(ModuleFieldType::Import, @@ -3557,10 +3882,110 @@ void CWriter::Write(const ExprList& exprs) { break; } + case ExprType::ReturnCall: { + const auto inst = cast(&expr); + const Func& func = *module_->GetFunc(inst->var); + + const FuncDeclaration& decl = func.decl; + assert(decl.sig.result_types == func_->decl.sig.result_types); + WriteTailCallAsserts(decl.sig); + WriteUnwindTryCatchStack(FindLabel(Var(label_stack_.size() - 1, {}))); + + if (!IsImport(func.name) && !func.features_used.tailcall) { + // make normal call, then return + Write(ExprList{std::make_unique(inst->var, inst->loc)}); + Write("goto ", LabelName(kImplicitFuncLabel), ";", Newline()); + return; + } + + Write(OpenBrace()); + if (!in_tail_callee_) { + WriteTailCallStack(); + } + + const Index num_params = decl.GetNumParams(); + if (num_params == 1) { + Write("wasm_rt_memcpy(tail_call_stack, &", StackVar(0), ", sizeof(", + StackVar(0), "));", Newline()); + } else if (num_params > 1) { + Write(OpenBrace(), decl.sig.param_types, " *tmp = (", + decl.sig.param_types, " *)tail_call_stack;", Newline()); + Spill(decl.sig.param_types, true); + Write(CloseBrace(), Newline()); + } + + Write("next->fn = ", TailCallRef(func.name), ";", Newline()); + if (IsImport(func.name)) { + Write("*instance_ptr = ", + GlobalName(ModuleFieldType::Import, + import_module_sym_map_.at(func.name)), + ";", Newline()); + } + DropTypes(num_params); + FinishReturnCall(); + Write(CloseBrace(), Newline()); + return; + } + + case ExprType::ReturnCallIndirect: { + const auto inst = cast(&expr); + const FuncDeclaration& decl = inst->decl; + assert(decl.sig.result_types == func_->decl.sig.result_types); + assert(decl.has_func_type); + const Index num_params = decl.GetNumParams(); + WriteTailCallAsserts(decl.sig); + WriteUnwindTryCatchStack(FindLabel(Var(label_stack_.size() - 1, {}))); + const Table* table = module_->GetTable(inst->table); + Write("CHECK_CALL_INDIRECT(", + ExternalInstanceRef(ModuleFieldType::Table, table->name), ", ", + FuncTypeExpr(module_->GetFuncType(decl.type_var)), ", ", + StackVar(0), ");", Newline()); + + Write("if (!", ExternalInstanceRef(ModuleFieldType::Table, table->name), + ".data[", StackVar(0), "].func_tailcallee.fn) ", OpenBrace()); + auto ci = std::make_unique(inst->loc); + std::tie(ci->decl, ci->table) = std::make_pair(inst->decl, inst->table); + Write(ExprList{std::move(ci)}); + Write("goto ", LabelName(kImplicitFuncLabel), ";", Newline()); + Write(CloseBrace(), Newline()); + + DropTypes(decl.GetNumResults()); + PushTypes(decl.sig.param_types); + PushType(Type::I32); + + Write(OpenBrace()); + if (!in_tail_callee_) { + WriteTailCallStack(); + } + + if (num_params == 1) { + Write("wasm_rt_memcpy(tail_call_stack, &", + StackVar(num_params, decl.GetResultType(0)), ", sizeof(", + decl.GetResultType(0), "));", Newline()); + } else if (num_params > 1) { + Write(OpenBrace(), decl.sig.param_types, " *tmp = (", + decl.sig.param_types, " *)tail_call_stack;", Newline()); + Spill(decl.sig.param_types, true, + [&](auto i) { return StackVar(num_params - i); }); + Write(CloseBrace(), Newline()); + } + + assert(decl.has_func_type); + Write("next->fn = ", + ExternalInstanceRef(ModuleFieldType::Table, table->name), + ".data[", StackVar(0), "].func_tailcallee.fn;", Newline()); + Write("*instance_ptr = ", + ExternalInstanceRef(ModuleFieldType::Table, table->name), + ".data[", StackVar(0), "].module_instance;", Newline()); + + DropTypes(num_params + 1); + FinishReturnCall(); + Write(CloseBrace(), Newline()); + return; + } + case ExprType::AtomicWait: case ExprType::AtomicNotify: - case ExprType::ReturnCall: - case ExprType::ReturnCallIndirect: case ExprType::CallRef: UNIMPLEMENTED("..."); break; @@ -5341,6 +5766,7 @@ void CWriter::WriteCHeader() { WriteImports(); WriteImportProperties(CWriterPhase::Declarations); WriteExports(CWriterPhase::Declarations); + WriteTailCallExports(CWriterPhase::Declarations); Write(Newline()); Write(s_header_bottom); Write(Newline(), "#endif /* ", guard, " */", Newline()); @@ -5359,6 +5785,7 @@ void CWriter::WriteCSource() { WriteFuncDeclarations(); WriteDataInitializerDecls(); WriteElemInitializerDecls(); + WriteMultivalueParamTypes(); /* Write the module-wide material to the first output stream */ stream_ = c_streams_.front(); @@ -5369,6 +5796,8 @@ void CWriter::WriteCSource() { WriteDataInitializers(); WriteElemInitializers(); WriteExports(CWriterPhase::Definitions); + WriteTailCallExports(CWriterPhase::Definitions); + WriteTailCallWeakImports(); WriteInitInstanceImport(); WriteImportProperties(CWriterPhase::Definitions); WriteInit(); diff --git a/src/prebuilt/wasm2c_source_declarations.cc b/src/prebuilt/wasm2c_source_declarations.cc index 5970f9a386..436904de8f 100644 --- a/src/prebuilt/wasm2c_source_declarations.cc +++ b/src/prebuilt/wasm2c_source_declarations.cc @@ -35,15 +35,23 @@ R"w2c_template( return (a == b) || LIKELY(a && b && !memcmp(a, b, 32)); R"w2c_template(} )w2c_template" R"w2c_template( -#define CALL_INDIRECT(table, t, ft, x, ...) \ +#define CHECK_CALL_INDIRECT(table, ft, x) \ )w2c_template" R"w2c_template( (LIKELY((x) < table.size && table.data[x].func && \ )w2c_template" R"w2c_template( func_types_eq(ft, table.data[x].func_type)) || \ )w2c_template" -R"w2c_template( TRAP(CALL_INDIRECT), \ +R"w2c_template( TRAP(CALL_INDIRECT)) )w2c_template" -R"w2c_template( ((t)table.data[x].func)(__VA_ARGS__)) +R"w2c_template( +#define DO_CALL_INDIRECT(table, t, x, ...) ((t)table.data[x].func)(__VA_ARGS__) +)w2c_template" +R"w2c_template( +#define CALL_INDIRECT(table, t, ft, x, ...) \ +)w2c_template" +R"w2c_template( (CHECK_CALL_INDIRECT(table, ft, x), \ +)w2c_template" +R"w2c_template( DO_CALL_INDIRECT(table, t, x, __VA_ARGS__)) )w2c_template" R"w2c_template( #ifdef SUPPORT_MEMORY64 @@ -940,6 +948,8 @@ R"w2c_template( wasm_rt_func_type_t type; )w2c_template" R"w2c_template( wasm_rt_function_ptr_t func; )w2c_template" +R"w2c_template( wasm_rt_tailcallee_t func_tailcallee; +)w2c_template" R"w2c_template( size_t module_offset; )w2c_template" R"w2c_template(} wasm_elem_segment_expr_t; @@ -971,11 +981,11 @@ R"w2c_template( for (u32 i = 0; i < n; i++) { )w2c_template" R"w2c_template( const wasm_elem_segment_expr_t* src_expr = &src[src_addr + i]; )w2c_template" -R"w2c_template( dest->data[dest_addr + i] = +R"w2c_template( dest->data[dest_addr + i] = (wasm_rt_funcref_t){ )w2c_template" -R"w2c_template( (wasm_rt_funcref_t){src_expr->type, src_expr->func, +R"w2c_template( src_expr->type, src_expr->func, src_expr->func_tailcallee, )w2c_template" -R"w2c_template( (char*)module_instance + src_expr->module_offset}; +R"w2c_template( (char*)module_instance + src_expr->module_offset}; )w2c_template" R"w2c_template( } )w2c_template" @@ -1125,4 +1135,17 @@ R"w2c_template(#define FUNC_TYPE_T(x) static const char x[] )w2c_template" R"w2c_template(#endif )w2c_template" +R"w2c_template( +#if (__STDC_VERSION__ >= 201112L) || defined(_Static_assert) +)w2c_template" +R"w2c_template(#define wasm_static_assert(X) _Static_assert(X, "assertion failure") +)w2c_template" +R"w2c_template(#else +)w2c_template" +R"w2c_template(#define wasm_static_assert(X) \ +)w2c_template" +R"w2c_template( extern int(*wasm2c_assert(void))[!!sizeof(struct { int x : (X) ? 2 : -1; })]; +)w2c_template" +R"w2c_template(#endif +)w2c_template" ; diff --git a/src/template/wasm2c.declarations.c b/src/template/wasm2c.declarations.c index 8ea080ca99..c1dd28118e 100644 --- a/src/template/wasm2c.declarations.c +++ b/src/template/wasm2c.declarations.c @@ -20,11 +20,16 @@ static inline bool func_types_eq(const wasm_rt_func_type_t a, return (a == b) || LIKELY(a && b && !memcmp(a, b, 32)); } -#define CALL_INDIRECT(table, t, ft, x, ...) \ +#define CHECK_CALL_INDIRECT(table, ft, x) \ (LIKELY((x) < table.size && table.data[x].func && \ func_types_eq(ft, table.data[x].func_type)) || \ - TRAP(CALL_INDIRECT), \ - ((t)table.data[x].func)(__VA_ARGS__)) + TRAP(CALL_INDIRECT)) + +#define DO_CALL_INDIRECT(table, t, x, ...) ((t)table.data[x].func)(__VA_ARGS__) + +#define CALL_INDIRECT(table, t, ft, x, ...) \ + (CHECK_CALL_INDIRECT(table, ft, x), \ + DO_CALL_INDIRECT(table, t, x, __VA_ARGS__)) #ifdef SUPPORT_MEMORY64 #define RANGE_CHECK(mem, offset, len) \ @@ -502,6 +507,7 @@ static inline void memory_init(wasm_rt_memory_t* dest, typedef struct { wasm_rt_func_type_t type; wasm_rt_function_ptr_t func; + wasm_rt_tailcallee_t func_tailcallee; size_t module_offset; } wasm_elem_segment_expr_t; @@ -518,9 +524,9 @@ static inline void funcref_table_init(wasm_rt_funcref_table_t* dest, TRAP(OOB); for (u32 i = 0; i < n; i++) { const wasm_elem_segment_expr_t* src_expr = &src[src_addr + i]; - dest->data[dest_addr + i] = - (wasm_rt_funcref_t){src_expr->type, src_expr->func, - (char*)module_instance + src_expr->module_offset}; + dest->data[dest_addr + i] = (wasm_rt_funcref_t){ + src_expr->type, src_expr->func, src_expr->func_tailcallee, + (char*)module_instance + src_expr->module_offset}; } } @@ -600,3 +606,10 @@ DEFINE_TABLE_FILL(externref) #define FUNC_TYPE_EXTERN_T(x) const char x[] #define FUNC_TYPE_T(x) static const char x[] #endif + +#if (__STDC_VERSION__ >= 201112L) || defined(_Static_assert) +#define wasm_static_assert(X) _Static_assert(X, "assertion failure") +#else +#define wasm_static_assert(X) \ + extern int(*wasm2c_assert(void))[!!sizeof(struct { int x : (X) ? 2 : -1; })]; +#endif diff --git a/src/tools/wasm2c.cc b/src/tools/wasm2c.cc index 10d9210ede..c1718d60c3 100644 --- a/src/tools/wasm2c.cc +++ b/src/tools/wasm2c.cc @@ -61,7 +61,7 @@ static const char s_description[] = static const std::string supported_features[] = { "multi-memory", "multi-value", "sign-extension", "saturating-float-to-int", "exceptions", "memory64", "extended-const", "simd", - "threads"}; + "threads", "tail-call"}; static bool IsFeatureSupported(const std::string& feature) { return std::find(std::begin(supported_features), std::end(supported_features), diff --git a/test/run-spec-wasm2c.py b/test/run-spec-wasm2c.py index 9f45499914..6a308e554b 100755 --- a/test/run-spec-wasm2c.py +++ b/test/run-spec-wasm2c.py @@ -533,6 +533,7 @@ def main(args): parser.add_argument('--enable-memory64', action='store_true') parser.add_argument('--enable-extended-const', action='store_true') parser.add_argument('--enable-threads', action='store_true') + parser.add_argument('--enable-tail-call', action='store_true') parser.add_argument('--disable-bulk-memory', action='store_true') parser.add_argument('--disable-reference-types', action='store_true') parser.add_argument('--debug-names', action='store_true') @@ -553,6 +554,7 @@ def main(args): '--enable-memory64': options.enable_memory64, '--enable-extended-const': options.enable_extended_const, '--enable-threads': options.enable_threads, + '--enable-tail-call': options.enable_tail_call, '--enable-multi-memory': options.enable_multi_memory, '--disable-bulk-memory': options.disable_bulk_memory, '--disable-reference-types': options.disable_reference_types, @@ -571,6 +573,7 @@ def main(args): '--enable-memory64': options.enable_memory64, '--enable-extended-const': options.enable_extended_const, '--enable-threads': options.enable_threads, + '--enable-tail-call': options.enable_tail_call, '--enable-multi-memory': options.enable_multi_memory}) options.cflags += shlex.split(os.environ.get('WASM2C_CFLAGS', '')) diff --git a/test/strip/remove_section.txt b/test/strip/remove_section.txt index 5a157f3ada..31a2fc380e 100644 --- a/test/strip/remove_section.txt +++ b/test/strip/remove_section.txt @@ -25,6 +25,7 @@ section("five") { "Ut enim ad minim veniam," } remove_section_stripped.wasm: file format wasm 0x1 Sections: + Custom start=0x0000000a end=0x0000002a (size=0x00000020) "two" Custom start=0x0000002c end=0x00000052 (size=0x00000026) "three" Custom start=0x00000054 end=0x0000007a (size=0x00000026) "four" diff --git a/test/wasm2c/add.txt b/test/wasm2c/add.txt index 783f443c6d..14543c780b 100644 --- a/test/wasm2c/add.txt +++ b/test/wasm2c/add.txt @@ -87,11 +87,16 @@ static inline bool func_types_eq(const wasm_rt_func_type_t a, return (a == b) || LIKELY(a && b && !memcmp(a, b, 32)); } -#define CALL_INDIRECT(table, t, ft, x, ...) \ +#define CHECK_CALL_INDIRECT(table, ft, x) \ (LIKELY((x) < table.size && table.data[x].func && \ func_types_eq(ft, table.data[x].func_type)) || \ - TRAP(CALL_INDIRECT), \ - ((t)table.data[x].func)(__VA_ARGS__)) + TRAP(CALL_INDIRECT)) + +#define DO_CALL_INDIRECT(table, t, x, ...) ((t)table.data[x].func)(__VA_ARGS__) + +#define CALL_INDIRECT(table, t, ft, x, ...) \ + (CHECK_CALL_INDIRECT(table, ft, x), \ + DO_CALL_INDIRECT(table, t, x, __VA_ARGS__)) #ifdef SUPPORT_MEMORY64 #define RANGE_CHECK(mem, offset, len) \ @@ -569,6 +574,7 @@ static inline void memory_init(wasm_rt_memory_t* dest, typedef struct { wasm_rt_func_type_t type; wasm_rt_function_ptr_t func; + wasm_rt_tailcallee_t func_tailcallee; size_t module_offset; } wasm_elem_segment_expr_t; @@ -585,9 +591,9 @@ static inline void funcref_table_init(wasm_rt_funcref_table_t* dest, TRAP(OOB); for (u32 i = 0; i < n; i++) { const wasm_elem_segment_expr_t* src_expr = &src[src_addr + i]; - dest->data[dest_addr + i] = - (wasm_rt_funcref_t){src_expr->type, src_expr->func, - (char*)module_instance + src_expr->module_offset}; + dest->data[dest_addr + i] = (wasm_rt_funcref_t){ + src_expr->type, src_expr->func, src_expr->func_tailcallee, + (char*)module_instance + src_expr->module_offset}; } } @@ -668,8 +674,23 @@ DEFINE_TABLE_FILL(externref) #define FUNC_TYPE_T(x) static const char x[] #endif +#if (__STDC_VERSION__ >= 201112L) || defined(_Static_assert) +#define wasm_static_assert(X) _Static_assert(X, "assertion failure") +#else +#define wasm_static_assert(X) \ + extern int(*wasm2c_assert(void))[!!sizeof(struct { int x : (X) ? 2 : -1; })]; +#endif + static u32 w2c_test_add_0(w2c_test*, u32, u32); +#ifndef wasm_multi_ii +#define wasm_multi_ii wasm_multi_ii +struct wasm_multi_ii { + u32 i0; + u32 i1; +}; +#endif /* wasm_multi_ii */ + FUNC_TYPE_T(w2c_test_t0) = "\x92\xfb\x6a\xdf\x49\x07\x0a\x83\xbe\x08\x02\x68\xcd\xf6\x95\x27\x4a\xc2\xf3\xe5\xe4\x7d\x29\x49\xe8\xed\x42\x92\x6a\x9d\xda\xf0"; /* export: 'add' */ diff --git a/test/wasm2c/check-imports.txt b/test/wasm2c/check-imports.txt index bae7f486a4..b336cb7720 100644 --- a/test/wasm2c/check-imports.txt +++ b/test/wasm2c/check-imports.txt @@ -110,11 +110,16 @@ static inline bool func_types_eq(const wasm_rt_func_type_t a, return (a == b) || LIKELY(a && b && !memcmp(a, b, 32)); } -#define CALL_INDIRECT(table, t, ft, x, ...) \ +#define CHECK_CALL_INDIRECT(table, ft, x) \ (LIKELY((x) < table.size && table.data[x].func && \ func_types_eq(ft, table.data[x].func_type)) || \ - TRAP(CALL_INDIRECT), \ - ((t)table.data[x].func)(__VA_ARGS__)) + TRAP(CALL_INDIRECT)) + +#define DO_CALL_INDIRECT(table, t, x, ...) ((t)table.data[x].func)(__VA_ARGS__) + +#define CALL_INDIRECT(table, t, ft, x, ...) \ + (CHECK_CALL_INDIRECT(table, ft, x), \ + DO_CALL_INDIRECT(table, t, x, __VA_ARGS__)) #ifdef SUPPORT_MEMORY64 #define RANGE_CHECK(mem, offset, len) \ @@ -592,6 +597,7 @@ static inline void memory_init(wasm_rt_memory_t* dest, typedef struct { wasm_rt_func_type_t type; wasm_rt_function_ptr_t func; + wasm_rt_tailcallee_t func_tailcallee; size_t module_offset; } wasm_elem_segment_expr_t; @@ -608,9 +614,9 @@ static inline void funcref_table_init(wasm_rt_funcref_table_t* dest, TRAP(OOB); for (u32 i = 0; i < n; i++) { const wasm_elem_segment_expr_t* src_expr = &src[src_addr + i]; - dest->data[dest_addr + i] = - (wasm_rt_funcref_t){src_expr->type, src_expr->func, - (char*)module_instance + src_expr->module_offset}; + dest->data[dest_addr + i] = (wasm_rt_funcref_t){ + src_expr->type, src_expr->func, src_expr->func_tailcallee, + (char*)module_instance + src_expr->module_offset}; } } @@ -691,10 +697,25 @@ DEFINE_TABLE_FILL(externref) #define FUNC_TYPE_T(x) static const char x[] #endif +#if (__STDC_VERSION__ >= 201112L) || defined(_Static_assert) +#define wasm_static_assert(X) _Static_assert(X, "assertion failure") +#else +#define wasm_static_assert(X) \ + extern int(*wasm2c_assert(void))[!!sizeof(struct { int x : (X) ? 2 : -1; })]; +#endif + static u32 w2c_test_f0(w2c_test*, u32); static u32 w2c_test_f1(w2c_test*); static u32 w2c_test_f2(w2c_test*, u32, u32); +#ifndef wasm_multi_ii +#define wasm_multi_ii wasm_multi_ii +struct wasm_multi_ii { + u32 i0; + u32 i1; +}; +#endif /* wasm_multi_ii */ + FUNC_TYPE_T(w2c_test_t0) = "\x07\x80\x96\x7a\x42\xf7\x3e\xe6\x70\x5c\x2f\xac\x83\xf5\x67\xd2\xa2\xa0\x69\x41\x5f\xf8\xe7\x96\x7f\x23\xab\x00\x03\x5f\x4a\x3c"; FUNC_TYPE_T(w2c_test_t1) = "\x72\xab\x00\xdf\x20\x3d\xce\xa1\xf2\x29\xc7\x9d\x13\x40\x7e\x98\xac\x7d\x41\x4a\x53\x2e\x42\x42\x61\x55\x2e\xaa\xeb\xbe\xc6\x35"; FUNC_TYPE_T(w2c_test_t2) = "\x92\xfb\x6a\xdf\x49\x07\x0a\x83\xbe\x08\x02\x68\xcd\xf6\x95\x27\x4a\xc2\xf3\xe5\xe4\x7d\x29\x49\xe8\xed\x42\x92\x6a\x9d\xda\xf0"; @@ -703,7 +724,7 @@ static void init_memories(w2c_test* instance) { } static const wasm_elem_segment_expr_t elem_segment_exprs_w2c_test_e0[] = { - {w2c_test_t1, (wasm_rt_function_ptr_t)w2c_test_f1, 0}, + {w2c_test_t1, (wasm_rt_function_ptr_t)w2c_test_f1, {NULL}, 0}, }; static void init_tables(w2c_test* instance) { diff --git a/test/wasm2c/export-names.txt b/test/wasm2c/export-names.txt index 596db40236..bb4536f924 100644 --- a/test/wasm2c/export-names.txt +++ b/test/wasm2c/export-names.txt @@ -110,11 +110,16 @@ static inline bool func_types_eq(const wasm_rt_func_type_t a, return (a == b) || LIKELY(a && b && !memcmp(a, b, 32)); } -#define CALL_INDIRECT(table, t, ft, x, ...) \ +#define CHECK_CALL_INDIRECT(table, ft, x) \ (LIKELY((x) < table.size && table.data[x].func && \ func_types_eq(ft, table.data[x].func_type)) || \ - TRAP(CALL_INDIRECT), \ - ((t)table.data[x].func)(__VA_ARGS__)) + TRAP(CALL_INDIRECT)) + +#define DO_CALL_INDIRECT(table, t, x, ...) ((t)table.data[x].func)(__VA_ARGS__) + +#define CALL_INDIRECT(table, t, ft, x, ...) \ + (CHECK_CALL_INDIRECT(table, ft, x), \ + DO_CALL_INDIRECT(table, t, x, __VA_ARGS__)) #ifdef SUPPORT_MEMORY64 #define RANGE_CHECK(mem, offset, len) \ @@ -592,6 +597,7 @@ static inline void memory_init(wasm_rt_memory_t* dest, typedef struct { wasm_rt_func_type_t type; wasm_rt_function_ptr_t func; + wasm_rt_tailcallee_t func_tailcallee; size_t module_offset; } wasm_elem_segment_expr_t; @@ -608,9 +614,9 @@ static inline void funcref_table_init(wasm_rt_funcref_table_t* dest, TRAP(OOB); for (u32 i = 0; i < n; i++) { const wasm_elem_segment_expr_t* src_expr = &src[src_addr + i]; - dest->data[dest_addr + i] = - (wasm_rt_funcref_t){src_expr->type, src_expr->func, - (char*)module_instance + src_expr->module_offset}; + dest->data[dest_addr + i] = (wasm_rt_funcref_t){ + src_expr->type, src_expr->func, src_expr->func_tailcallee, + (char*)module_instance + src_expr->module_offset}; } } @@ -691,6 +697,13 @@ DEFINE_TABLE_FILL(externref) #define FUNC_TYPE_T(x) static const char x[] #endif +#if (__STDC_VERSION__ >= 201112L) || defined(_Static_assert) +#define wasm_static_assert(X) _Static_assert(X, "assertion failure") +#else +#define wasm_static_assert(X) \ + extern int(*wasm2c_assert(void))[!!sizeof(struct { int x : (X) ? 2 : -1; })]; +#endif + static void w2c_test__0(w2c_test*); FUNC_TYPE_T(w2c_test_t0) = "\x36\xa9\xe7\xf1\xc9\x5b\x82\xff\xb9\x97\x43\xe0\xc5\xc4\xce\x95\xd8\x3c\x9a\x43\x0a\xac\x59\xf8\x4e\xf3\xcb\xfa\xb6\x14\x50\x68"; diff --git a/test/wasm2c/hello.txt b/test/wasm2c/hello.txt index 752725e862..b16703fb50 100644 --- a/test/wasm2c/hello.txt +++ b/test/wasm2c/hello.txt @@ -62,9 +62,11 @@ wasm_rt_func_type_t wasm2c_test_get_func_type(uint32_t param_count, uint32_t res /* import: 'wasi_snapshot_preview1' 'fd_write' */ u32 w2c_wasi__snapshot__preview1_fd_write(struct w2c_wasi__snapshot__preview1*, u32, u32, u32, u32); +void wasm2c_tailcall_w2c_wasi__snapshot__preview1_fd_write(void **instance_ptr, void *tail_call_stack, wasm_rt_tailcallee_t *next); /* import: 'wasi_snapshot_preview1' 'proc_exit' */ void w2c_wasi__snapshot__preview1_proc_exit(struct w2c_wasi__snapshot__preview1*, u32); +void wasm2c_tailcall_w2c_wasi__snapshot__preview1_proc_exit(void **instance_ptr, void *tail_call_stack, wasm_rt_tailcallee_t *next); /* export: 'memory' */ wasm_rt_memory_t* w2c_test_memory(w2c_test* instance); @@ -118,11 +120,16 @@ static inline bool func_types_eq(const wasm_rt_func_type_t a, return (a == b) || LIKELY(a && b && !memcmp(a, b, 32)); } -#define CALL_INDIRECT(table, t, ft, x, ...) \ +#define CHECK_CALL_INDIRECT(table, ft, x) \ (LIKELY((x) < table.size && table.data[x].func && \ func_types_eq(ft, table.data[x].func_type)) || \ - TRAP(CALL_INDIRECT), \ - ((t)table.data[x].func)(__VA_ARGS__)) + TRAP(CALL_INDIRECT)) + +#define DO_CALL_INDIRECT(table, t, x, ...) ((t)table.data[x].func)(__VA_ARGS__) + +#define CALL_INDIRECT(table, t, ft, x, ...) \ + (CHECK_CALL_INDIRECT(table, ft, x), \ + DO_CALL_INDIRECT(table, t, x, __VA_ARGS__)) #ifdef SUPPORT_MEMORY64 #define RANGE_CHECK(mem, offset, len) \ @@ -600,6 +607,7 @@ static inline void memory_init(wasm_rt_memory_t* dest, typedef struct { wasm_rt_func_type_t type; wasm_rt_function_ptr_t func; + wasm_rt_tailcallee_t func_tailcallee; size_t module_offset; } wasm_elem_segment_expr_t; @@ -616,9 +624,9 @@ static inline void funcref_table_init(wasm_rt_funcref_table_t* dest, TRAP(OOB); for (u32 i = 0; i < n; i++) { const wasm_elem_segment_expr_t* src_expr = &src[src_addr + i]; - dest->data[dest_addr + i] = - (wasm_rt_funcref_t){src_expr->type, src_expr->func, - (char*)module_instance + src_expr->module_offset}; + dest->data[dest_addr + i] = (wasm_rt_funcref_t){ + src_expr->type, src_expr->func, src_expr->func_tailcallee, + (char*)module_instance + src_expr->module_offset}; } } @@ -699,8 +707,25 @@ DEFINE_TABLE_FILL(externref) #define FUNC_TYPE_T(x) static const char x[] #endif +#if (__STDC_VERSION__ >= 201112L) || defined(_Static_assert) +#define wasm_static_assert(X) _Static_assert(X, "assertion failure") +#else +#define wasm_static_assert(X) \ + extern int(*wasm2c_assert(void))[!!sizeof(struct { int x : (X) ? 2 : -1; })]; +#endif + static void w2c_test_0x5Fstart_0(w2c_test*); +#ifndef wasm_multi_iiii +#define wasm_multi_iiii wasm_multi_iiii +struct wasm_multi_iiii { + u32 i0; + u32 i1; + u32 i2; + u32 i3; +}; +#endif /* wasm_multi_iiii */ + FUNC_TYPE_T(w2c_test_t0) = "\xf6\x98\x1b\xc6\x10\xda\xb7\xb2\x63\x37\xcd\xdc\x72\xca\xe9\x50\x00\x13\xba\x10\x6c\xde\x87\x27\x10\xf8\x86\x2f\xe3\xdb\x94\xe4"; FUNC_TYPE_T(w2c_test_t1) = "\x89\x3a\x3d\x2c\x8f\x4d\x7f\x6d\x6c\x9d\x62\x67\x29\xaf\x3d\x44\x39\x8e\xc3\xf3\xe8\x51\xc1\x99\xb9\xdd\x9f\xd5\x3d\x1f\xd3\xe4"; FUNC_TYPE_T(w2c_test_t2) = "\x36\xa9\xe7\xf1\xc9\x5b\x82\xff\xb9\x97\x43\xe0\xc5\xc4\xce\x95\xd8\x3c\x9a\x43\x0a\xac\x59\xf8\x4e\xf3\xcb\xfa\xb6\x14\x50\x68"; @@ -719,7 +744,7 @@ static void init_data_instances(w2c_test *instance) { } static const wasm_elem_segment_expr_t elem_segment_exprs_w2c_test_e0[] = { - {w2c_test_t0, (wasm_rt_function_ptr_t)w2c_wasi__snapshot__preview1_fd_write, offsetof(w2c_test, w2c_wasi__snapshot__preview1_instance)}, + {w2c_test_t0, (wasm_rt_function_ptr_t)w2c_wasi__snapshot__preview1_fd_write, {wasm2c_tailcall_w2c_wasi__snapshot__preview1_fd_write}, offsetof(w2c_test, w2c_wasi__snapshot__preview1_instance)}, }; static void init_tables(w2c_test* instance) { @@ -740,6 +765,31 @@ void w2c_test_0x5Fstart(w2c_test* instance) { w2c_test_0x5Fstart_0(instance); } +/* handler for missing tail-call on import: 'wasi_snapshot_preview1' 'fd_write' */ +#ifdef _MSC_VER +#pragma comment(linker, "/alternatename:wasm2c_tailcall_w2c_wasi__snapshot__preview1_fd_write=wasm2c_fallback_test_w2c_wasi__snapshot__preview1_fd_write") +void wasm2c_fallback_test_w2c_wasi__snapshot__preview1_fd_write(void **instance_ptr, void *tail_call_stack, wasm_rt_tailcallee_t *next) +#else +__attribute__((weak)) void wasm2c_tailcall_w2c_wasi__snapshot__preview1_fd_write(void **instance_ptr, void *tail_call_stack, wasm_rt_tailcallee_t *next) +#endif +{ + struct wasm_multi_iiii *params = tail_call_stack; + u32 result = w2c_wasi__snapshot__preview1_fd_write(*instance_ptr, params->i0, params->i1, params->i2, params->i3); + wasm_rt_memcpy(tail_call_stack, &result, sizeof(result)); +} + +/* handler for missing tail-call on import: 'wasi_snapshot_preview1' 'proc_exit' */ +#ifdef _MSC_VER +#pragma comment(linker, "/alternatename:wasm2c_tailcall_w2c_wasi__snapshot__preview1_proc_exit=wasm2c_fallback_test_w2c_wasi__snapshot__preview1_proc_exit") +void wasm2c_fallback_test_w2c_wasi__snapshot__preview1_proc_exit(void **instance_ptr, void *tail_call_stack, wasm_rt_tailcallee_t *next) +#else +__attribute__((weak)) void wasm2c_tailcall_w2c_wasi__snapshot__preview1_proc_exit(void **instance_ptr, void *tail_call_stack, wasm_rt_tailcallee_t *next) +#endif +{ + u32 param = *(u32*)tail_call_stack; + w2c_wasi__snapshot__preview1_proc_exit(*instance_ptr, param); +} + static void init_instance_import(w2c_test* instance, struct w2c_wasi__snapshot__preview1* w2c_wasi__snapshot__preview1_instance) { instance->w2c_wasi__snapshot__preview1_instance = w2c_wasi__snapshot__preview1_instance; } diff --git a/test/wasm2c/minimal.txt b/test/wasm2c/minimal.txt index 08c724012d..4e8cd43c36 100644 --- a/test/wasm2c/minimal.txt +++ b/test/wasm2c/minimal.txt @@ -81,11 +81,16 @@ static inline bool func_types_eq(const wasm_rt_func_type_t a, return (a == b) || LIKELY(a && b && !memcmp(a, b, 32)); } -#define CALL_INDIRECT(table, t, ft, x, ...) \ +#define CHECK_CALL_INDIRECT(table, ft, x) \ (LIKELY((x) < table.size && table.data[x].func && \ func_types_eq(ft, table.data[x].func_type)) || \ - TRAP(CALL_INDIRECT), \ - ((t)table.data[x].func)(__VA_ARGS__)) + TRAP(CALL_INDIRECT)) + +#define DO_CALL_INDIRECT(table, t, x, ...) ((t)table.data[x].func)(__VA_ARGS__) + +#define CALL_INDIRECT(table, t, ft, x, ...) \ + (CHECK_CALL_INDIRECT(table, ft, x), \ + DO_CALL_INDIRECT(table, t, x, __VA_ARGS__)) #ifdef SUPPORT_MEMORY64 #define RANGE_CHECK(mem, offset, len) \ @@ -563,6 +568,7 @@ static inline void memory_init(wasm_rt_memory_t* dest, typedef struct { wasm_rt_func_type_t type; wasm_rt_function_ptr_t func; + wasm_rt_tailcallee_t func_tailcallee; size_t module_offset; } wasm_elem_segment_expr_t; @@ -579,9 +585,9 @@ static inline void funcref_table_init(wasm_rt_funcref_table_t* dest, TRAP(OOB); for (u32 i = 0; i < n; i++) { const wasm_elem_segment_expr_t* src_expr = &src[src_addr + i]; - dest->data[dest_addr + i] = - (wasm_rt_funcref_t){src_expr->type, src_expr->func, - (char*)module_instance + src_expr->module_offset}; + dest->data[dest_addr + i] = (wasm_rt_funcref_t){ + src_expr->type, src_expr->func, src_expr->func_tailcallee, + (char*)module_instance + src_expr->module_offset}; } } @@ -662,6 +668,13 @@ DEFINE_TABLE_FILL(externref) #define FUNC_TYPE_T(x) static const char x[] #endif +#if (__STDC_VERSION__ >= 201112L) || defined(_Static_assert) +#define wasm_static_assert(X) _Static_assert(X, "assertion failure") +#else +#define wasm_static_assert(X) \ + extern int(*wasm2c_assert(void))[!!sizeof(struct { int x : (X) ? 2 : -1; })]; +#endif + void wasm2c_test_instantiate(w2c_test* instance) { assert(wasm_rt_is_initialized()); } diff --git a/test/wasm2c/spec/tail-call/return_call.txt b/test/wasm2c/spec/tail-call/return_call.txt new file mode 100644 index 0000000000..6218eeac59 --- /dev/null +++ b/test/wasm2c/spec/tail-call/return_call.txt @@ -0,0 +1,6 @@ +;;; TOOL: run-spec-wasm2c +;;; STDIN_FILE: third_party/testsuite/proposals/tail-call/return_call.wast +;;; ARGS*: --enable-tail-call +(;; STDOUT ;;; +31/31 tests passed. +;;; STDOUT ;;) diff --git a/test/wasm2c/spec/tail-call/return_call_indirect.txt b/test/wasm2c/spec/tail-call/return_call_indirect.txt new file mode 100644 index 0000000000..8cf46bd900 --- /dev/null +++ b/test/wasm2c/spec/tail-call/return_call_indirect.txt @@ -0,0 +1,6 @@ +;;; TOOL: run-spec-wasm2c +;;; STDIN_FILE: third_party/testsuite/proposals/tail-call/return_call_indirect.wast +;;; ARGS*: --enable-tail-call +(;; STDOUT ;;; +47/47 tests passed. +;;; STDOUT ;;) diff --git a/wasm2c/examples/callback/main.c b/wasm2c/examples/callback/main.c index 8c6f9db09a..8d9ab59e1d 100644 --- a/wasm2c/examples/callback/main.c +++ b/wasm2c/examples/callback/main.c @@ -26,7 +26,8 @@ int main(int argc, char** argv) { */ wasm_rt_func_type_t fn_type = wasm2c_callback_get_func_type(1, 0, WASM_RT_I32); - wasm_rt_funcref_t fn_ref = {fn_type, (wasm_rt_function_ptr_t)print, &inst}; + wasm_rt_funcref_t fn_ref = { + fn_type, (wasm_rt_function_ptr_t)print, {NULL}, &inst}; w2c_callback_set_print_function(&inst, fn_ref); /* "say_hello" uses the previously installed callback. */ diff --git a/wasm2c/examples/fac/fac.c b/wasm2c/examples/fac/fac.c index 31bceb949a..254b56887e 100644 --- a/wasm2c/examples/fac/fac.c +++ b/wasm2c/examples/fac/fac.c @@ -39,11 +39,16 @@ static inline bool func_types_eq(const wasm_rt_func_type_t a, return (a == b) || LIKELY(a && b && !memcmp(a, b, 32)); } -#define CALL_INDIRECT(table, t, ft, x, ...) \ +#define CHECK_CALL_INDIRECT(table, ft, x) \ (LIKELY((x) < table.size && table.data[x].func && \ func_types_eq(ft, table.data[x].func_type)) || \ - TRAP(CALL_INDIRECT), \ - ((t)table.data[x].func)(__VA_ARGS__)) + TRAP(CALL_INDIRECT)) + +#define DO_CALL_INDIRECT(table, t, x, ...) ((t)table.data[x].func)(__VA_ARGS__) + +#define CALL_INDIRECT(table, t, ft, x, ...) \ + (CHECK_CALL_INDIRECT(table, ft, x), \ + DO_CALL_INDIRECT(table, t, x, __VA_ARGS__)) #ifdef SUPPORT_MEMORY64 #define RANGE_CHECK(mem, offset, len) \ @@ -521,6 +526,7 @@ static inline void memory_init(wasm_rt_memory_t* dest, typedef struct { wasm_rt_func_type_t type; wasm_rt_function_ptr_t func; + wasm_rt_tailcallee_t func_tailcallee; size_t module_offset; } wasm_elem_segment_expr_t; @@ -537,9 +543,9 @@ static inline void funcref_table_init(wasm_rt_funcref_table_t* dest, TRAP(OOB); for (u32 i = 0; i < n; i++) { const wasm_elem_segment_expr_t* src_expr = &src[src_addr + i]; - dest->data[dest_addr + i] = - (wasm_rt_funcref_t){src_expr->type, src_expr->func, - (char*)module_instance + src_expr->module_offset}; + dest->data[dest_addr + i] = (wasm_rt_funcref_t){ + src_expr->type, src_expr->func, src_expr->func_tailcallee, + (char*)module_instance + src_expr->module_offset}; } } @@ -620,6 +626,13 @@ DEFINE_TABLE_FILL(externref) #define FUNC_TYPE_T(x) static const char x[] #endif +#if (__STDC_VERSION__ >= 201112L) || defined(_Static_assert) +#define wasm_static_assert(X) _Static_assert(X, "assertion failure") +#else +#define wasm_static_assert(X) \ + extern int(*wasm2c_assert(void))[!!sizeof(struct { int x : (X) ? 2 : -1; })]; +#endif + static u32 w2c_fac_fac_0(w2c_fac*, u32); FUNC_TYPE_T(w2c_fac_t0) = "\x07\x80\x96\x7a\x42\xf7\x3e\xe6\x70\x5c\x2f\xac\x83\xf5\x67\xd2\xa2\xa0\x69\x41\x5f\xf8\xe7\x96\x7f\x23\xab\x00\x03\x5f\x4a\x3c"; diff --git a/wasm2c/wasm-rt-impl.c b/wasm2c/wasm-rt-impl.c index 542230871a..99bae70705 100644 --- a/wasm2c/wasm-rt-impl.c +++ b/wasm2c/wasm-rt-impl.c @@ -414,7 +414,7 @@ const char* wasm_rt_strerror(wasm_rt_trap_t trap) { case WASM_RT_TRAP_UNREACHABLE: return "Unreachable instruction executed"; case WASM_RT_TRAP_CALL_INDIRECT: - return "Invalid call_indirect"; + return "Invalid call_indirect or return_call_indirect"; case WASM_RT_TRAP_UNCAUGHT_EXCEPTION: return "Uncaught exception"; case WASM_RT_TRAP_UNALIGNED: diff --git a/wasm2c/wasm-rt.h b/wasm2c/wasm-rt.h index bdeecd569b..1b123f65cc 100644 --- a/wasm2c/wasm-rt.h +++ b/wasm2c/wasm-rt.h @@ -245,6 +245,15 @@ typedef enum { */ typedef void (*wasm_rt_function_ptr_t)(void); +/** + * A pointer to a "tail-callee" function, called by a tail-call + * trampoline or by another tail-callee function. (The definition uses a + * single-member struct to allow a recursive definition.) + */ +typedef struct wasm_rt_tailcallee_t { + void (*fn)(void**, void*, struct wasm_rt_tailcallee_t*); +} wasm_rt_tailcallee_t; + /** * The type of a function (an arbitrary number of param and result types). * This is represented as an opaque 256-bit ID. @@ -259,6 +268,8 @@ typedef struct { /** The function. The embedder must know the actual C signature of the * function and cast to it before calling. */ wasm_rt_function_ptr_t func; + /** An alternate version of the function to be used when tail-called. */ + wasm_rt_tailcallee_t func_tailcallee; /** A function instance is a closure of the function over an instance * of the originating module. The module_instance element will be passed into * the function at runtime. */ @@ -266,7 +277,10 @@ typedef struct { } wasm_rt_funcref_t; /** Default (null) value of a funcref */ -static const wasm_rt_funcref_t wasm_rt_funcref_null_value = {NULL, NULL, NULL}; +static const wasm_rt_funcref_t wasm_rt_funcref_null_value = {NULL, + NULL, + {NULL}, + NULL}; /** The type of an external reference (opaque to WebAssembly). */ typedef void* wasm_rt_externref_t;