Skip to content

Commit

Permalink
refactor register_cpp_function
Browse files Browse the repository at this point in the history
  • Loading branch information
chloro-pn committed Jun 14, 2024
1 parent 76f4adf commit 905e0c2
Show file tree
Hide file tree
Showing 14 changed files with 133 additions and 28 deletions.
3 changes: 2 additions & 1 deletion TODO.md
Original file line number Diff line number Diff line change
Expand Up @@ -39,4 +39,5 @@
* 返回类型为void时的return语句优化
* [done] 支持析构函数
* 支持内置函数assert
* 支持 enum to string
* 支持 enum to string
* 内置函数放置在wamon包名称中
6 changes: 3 additions & 3 deletions example/register_cpp_function.cc
Original file line number Diff line number Diff line change
Expand Up @@ -33,19 +33,19 @@ int main() {
package main;
func call_cpp_function(string s) -> int {
return call my_cpp_func:(s);
return call wamon::my_cpp_func:(s);
}
let test : f(() -> int) = my_type_cpp_func;
let test : f(() -> int) = wamon::my_type_cpp_func;
)";

wamon::Scanner scanner;
auto tokens = scanner.Scan(script);
wamon::PackageUnit package_unit = wamon::Parse(tokens);
package_unit = wamon::MergePackageUnits(std::move(package_unit));
package_unit.RegisterCppFunctions("my_cpp_func", my_cpp_func_check, my_cpp_func);
package_unit.RegisterCppFunctions("my_type_cpp_func", wamon::TypeFactory<int()>::Get(), my_type_cpp_func);

package_unit = wamon::MergePackageUnits(std::move(package_unit));
wamon::TypeChecker type_checker(package_unit);

std::string reason;
Expand Down
23 changes: 23 additions & 0 deletions include/wamon/builtin_functions.h
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@ class BuiltinFunctions {

BuiltinFunctions();

static void RegisterWamonBuiltinFunction(BuiltinFunctions&);

BuiltinFunctions(BuiltinFunctions&&) = default;

BuiltinFunctions& operator=(BuiltinFunctions&&) = default;
Expand Down Expand Up @@ -67,6 +69,27 @@ class BuiltinFunctions {
builtin_types_[name] = std::move(type);
}

void Merge(BuiltinFunctions&& other) {
for (auto& each : other.builtin_handles_) {
if (builtin_handles_.count(each.first) > 0) {
throw WamonException("BuiltinFunctions.Merge error, duplicate handle {}", each.first);
}
builtin_handles_[each.first] = std::move(each.second);
}
for (auto& each : other.builtin_checks_) {
if (builtin_checks_.count(each.first) > 0) {
throw WamonException("BuiltinFunctions.Merge error, duplicate check {}", each.first);
}
builtin_checks_[each.first] = std::move(each.second);
}
for (auto& each : other.builtin_types_) {
if (builtin_types_.count(each.first) > 0) {
throw WamonException("BuiltinFunctions.Merge error, duplicate typed name {}", each.first);
}
builtin_types_[each.first] = std::move(each.second);
}
}

private:
std::unordered_map<std::string, HandleType> builtin_handles_;
std::unordered_map<std::string, CheckType> builtin_checks_;
Expand Down
30 changes: 28 additions & 2 deletions include/wamon/package_unit.h
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
#include "wamon/function_def.h"
#include "wamon/method_def.h"
#include "wamon/operator_def.h"
#include "wamon/prepared_package_name.h"
#include "wamon/struct_def.h"
#include "wamon/type.h"

Expand All @@ -24,6 +25,7 @@ class PackageUnit {
static PackageUnit _MergePackageUnits(std::vector<PackageUnit>&& packages);

std::string CreateUniqueLambdaName() {
MergedFlagCheck("CreateUniqueLambdaName");
auto ret = "__lambda_" + std::to_string(lambda_count_) + GetName();
lambda_count_ += 1;
return ret;
Expand All @@ -34,9 +36,16 @@ class PackageUnit {
PackageUnit(PackageUnit&&) = default;
PackageUnit& operator=(PackageUnit&&) = default;

void SetName(const std::string& name) { package_name_ = name; }
void SetName(const std::string& name) {
MergedFlagCheck("SetName");
if (IsPreparedPakageName(name)) {
throw WamonException("PackageUnit.SetName invalid , {}", name);
}
package_name_ = name;
}

void SetImportPackage(const std::vector<std::string>& import_packages) {
MergedFlagCheck("SetImportPackage");
import_packages_ = import_packages;
package_imports_[package_name_] = import_packages;
}
Expand All @@ -46,6 +55,7 @@ class PackageUnit {
const std::vector<std::string>& GetImportPackage() const { return import_packages_; }

void AddVarDef(std::unique_ptr<VariableDefineStmt>&& vd) {
MergedFlagCheck("AddVarDef");
if (std::find_if(var_define_.begin(), var_define_.end(), [&vd](const auto& v) -> bool {
return vd->GetVarName() == v->GetVarName();
}) != var_define_.end()) {
Expand All @@ -55,6 +65,7 @@ class PackageUnit {
}

void AddFuncDef(std::unique_ptr<FunctionDef>&& func_def) {
MergedFlagCheck("AddFuncDef");
auto name = func_def->GetFunctionName();
if (funcs_.find(name) != funcs_.end()) {
throw WamonException("duplicate func {}", name);
Expand All @@ -63,6 +74,7 @@ class PackageUnit {
}

void AddStructDef(std::unique_ptr<StructDef>&& struct_def) {
MergedFlagCheck("AddStructDef");
auto name = struct_def->GetStructName();
if (structs_.find(name) != structs_.end()) {
throw WamonException("duplicate struct {}", name);
Expand All @@ -71,6 +83,7 @@ class PackageUnit {
}

void AddEnumDef(std::unique_ptr<EnumDef>&& enum_def) {
MergedFlagCheck("AddEnumDef");
auto name = enum_def->GetEnumName();
if (enum_def->GetEnumItems().empty()) {
throw WamonException("AddEnumDef error, empty enum {}", name);
Expand All @@ -82,6 +95,7 @@ class PackageUnit {
}

void AddMethod(const std::string& type_name, std::unique_ptr<methods_def>&& methods) {
MergedFlagCheck("AddMethod");
assert(type_name.empty() == false);
if (structs_.find(type_name) == structs_.end()) {
throw WamonException("add method error, invalid type : {}", type_name);
Expand All @@ -90,6 +104,7 @@ class PackageUnit {
}

void AddLambdaFunction(const std::string& lambda_name, std::unique_ptr<FunctionDef>&& lambda) {
MergedFlagCheck("AddLambdaFunction");
assert(LambdaExpr::IsLambdaName(lambda_name));
if (funcs_.find(lambda_name) != funcs_.end()) {
throw WamonException("PackageUnit.AddLambdaFunction error, duplicate function name {}", lambda_name);
Expand Down Expand Up @@ -152,6 +167,7 @@ class PackageUnit {
BuiltinFunctions& GetBuiltinFunctions() { return builtin_functions_; }

void AddPackageImports(const std::string& package, const std::vector<std::string>& imports) {
MergedFlagCheck("AddPackageImports");
if (package_imports_.count(package) > 0) {
auto& tmp = package_imports_[package];
for (auto& each : imports) {
Expand All @@ -171,7 +187,8 @@ class PackageUnit {
}

void RegisterCppFunctions(const std::string& name, BuiltinFunctions::CheckType ct, BuiltinFunctions::HandleType ht) {
GetBuiltinFunctions().Register(name, std::move(ct), std::move(ht));
MergedFlagCheck("RegisterCppFunctions");
GetBuiltinFunctions().Register(GetName() + "$" + name, std::move(ct), std::move(ht));
}

void RegisterCppFunctions(const std::string& name, std::unique_ptr<Type> func_type, BuiltinFunctions::HandleType ht);
Expand All @@ -185,6 +202,15 @@ class PackageUnit {
void SetCurrentParsingPackage(const std::string& cpp) { current_parsing_package_ = cpp; }

private:
// PackageUnit有两种类型,只有Merge得到的PackageUnit才能够进行类型分析以及交付给解释器执行。
bool merged = false;

void MergedFlagCheck(const std::string& info) {
if (merged == true) {
throw WamonException("MergedFlagCheck error, from {}", info);
}
}

std::string package_name_;
std::vector<std::string> import_packages_;
// 包作用域的变量定义语句
Expand Down
9 changes: 9 additions & 0 deletions include/wamon/prepared_package_name.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
#pragma once

#include <string>

namespace wamon {

bool IsPreparedPakageName(const std::string& package_name);

}
2 changes: 1 addition & 1 deletion include/wamon/static_analyzer.h
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ class StaticAnalyzer {
return fd->GetType();
}
// 只有设定类型的注册函数能够被找到
auto ftype = GetPackageUnit().GetBuiltinFunctions().GetType(GetIdFromIdent(name));
auto ftype = GetPackageUnit().GetBuiltinFunctions().GetType(name);
if (ftype != nullptr) {
type = IdExpr::Type::BuiltinFunc;
return ftype;
Expand Down
2 changes: 1 addition & 1 deletion src/ast.cc
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ std::shared_ptr<Variable> FuncCallExpr::Calculate(Interpreter& interpreter) {
if (type == FuncCallType::FUNC) {
func_name = dynamic_cast<IdExpr*>(caller_.get())->GenerateIdent();
} else if (type == FuncCallType::BUILT_IN_FUNC) {
func_name = dynamic_cast<IdExpr*>(caller_.get())->GetId();
func_name = dynamic_cast<IdExpr*>(caller_.get())->GenerateIdent();
}
if (type == FuncCallType::FUNC) {
auto funcdef = interpreter.GetPackageUnit().FindFunction(func_name);
Expand Down
26 changes: 15 additions & 11 deletions src/builtin_functions.cc
Original file line number Diff line number Diff line change
Expand Up @@ -54,10 +54,11 @@ static auto _to_string(Interpreter& ip, std::vector<std::shared_ptr<Variable>>&&
return std::make_shared<StringVariable>(result, wamon::Variable::ValueCategory::RValue, "");
}

static void register_builtin_handles(std::unordered_map<std::string, BuiltinFunctions::HandleType>& handles) {
handles["print"] = _print;
handles["to_string"] = _to_string;
handles["context_stack"] = _context_stack;
static void register_builtin_handles(const std::string& prefix,
std::unordered_map<std::string, BuiltinFunctions::HandleType>& handles) {
handles[prefix + "print"] = _print;
handles[prefix + "to_string"] = _to_string;
handles[prefix + "context_stack"] = _context_stack;
}

static auto _print_check(const std::vector<std::unique_ptr<Type>>& params_type) -> std::unique_ptr<Type> {
Expand All @@ -82,15 +83,18 @@ static auto _to_string_check(const std::vector<std::unique_ptr<Type>>& params_ty
throw WamonException("to_string type_check error, type {} cant not be to_string", type->GetTypeInfo());
}

static void register_builtin_checks(std::unordered_map<std::string, BuiltinFunctions::CheckType>& checks) {
checks["print"] = _print_check;
checks["to_string"] = _to_string_check;
checks["context_stack"] = _context_stack_check;
static void register_builtin_checks(const std::string& prefix,
std::unordered_map<std::string, BuiltinFunctions::CheckType>& checks) {
checks[prefix + "print"] = _print_check;
checks[prefix + "to_string"] = _to_string_check;
checks[prefix + "context_stack"] = _context_stack_check;
}

BuiltinFunctions::BuiltinFunctions() {
register_builtin_checks(builtin_checks_);
register_builtin_handles(builtin_handles_);
BuiltinFunctions::BuiltinFunctions() {}

void BuiltinFunctions::RegisterWamonBuiltinFunction(BuiltinFunctions& obj) {
register_builtin_checks("wamon$", obj.builtin_checks_);
register_builtin_handles("wamon$", obj.builtin_handles_);
}

std::unique_ptr<Type> BuiltinFunctions::TypeCheck(const std::string& name,
Expand Down
7 changes: 6 additions & 1 deletion src/package_unit.cc
Original file line number Diff line number Diff line change
Expand Up @@ -37,13 +37,18 @@ PackageUnit PackageUnit::_MergePackageUnits(std::vector<PackageUnit>&& packages)
struct_define->second->SetStructName(it->GetName() + "$" + struct_name);
result.AddStructDef(std::move(struct_define->second));
}

result.builtin_functions_.Merge(std::move(it->builtin_functions_));
}
BuiltinFunctions::RegisterWamonBuiltinFunction(result.builtin_functions_);
// 不需要更新 lambda_count_,因为merge之后的PackageUnit不会再进行解析了。
result.merged = true;
return result;
}

void PackageUnit::RegisterCppFunctions(const std::string& name, std::unique_ptr<Type> func_type,
BuiltinFunctions::HandleType ht) {
MergedFlagCheck("RegisterCppFunctions");
if (IsFuncType(func_type) == false) {
throw WamonException("RegisterCppFunctions error, {} have non-function type : {}", name, func_type->GetTypeInfo());
}
Expand All @@ -67,7 +72,7 @@ void PackageUnit::RegisterCppFunctions(const std::string& name, std::unique_ptr<
return GetReturnType(func_type);
};
RegisterCppFunctions(name, std::move(check_f), std::move(ht));
GetBuiltinFunctions().SetTypeForFunction(name, std::move(func_type));
GetBuiltinFunctions().SetTypeForFunction(GetName() + "$" + name, std::move(func_type));
}

} // namespace wamon
6 changes: 6 additions & 0 deletions src/parsing_package.cc
Original file line number Diff line number Diff line change
@@ -1,14 +1,20 @@
#include "wamon/parsing_package.h"

#include <algorithm>

#include "wamon/exception.h"
#include "wamon/package_unit.h"
#include "wamon/prepared_package_name.h"

namespace wamon {

void AssertInImportListOrThrow(PackageUnit& pu, const std::string& package_name) {
if (package_name == pu.GetCurrentParsingPackage()) {
return;
}
if (IsPreparedPakageName(package_name)) {
return;
}
for (auto& each : pu.GetCurrentParsingImports()) {
if (package_name == each) {
return;
Expand Down
27 changes: 27 additions & 0 deletions src/prepared_package_name.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
#include "wamon/prepared_package_name.h"

namespace wamon {

namespace detail {

template <typename T, size_t N>
constexpr size_t GetArrayLength(T (&)[N]) {
return N;
}

} // namespace detail

bool IsPreparedPakageName(const std::string& package_name) {
static std::string prepared_package_name[] = {
"wamon",
};

for (size_t i = 0; i < detail::GetArrayLength(prepared_package_name); ++i) {
if (package_name == prepared_package_name[i]) {
return true;
}
}
return false;
}

} // namespace wamon
6 changes: 4 additions & 2 deletions src/type_checker.cc
Original file line number Diff line number Diff line change
Expand Up @@ -562,13 +562,15 @@ std::unique_ptr<Type> CheckAndGetFuncReturnType(const TypeChecker& tc, const Fun
*/
std::unique_ptr<Type> CheckParamTypeAndGetResultTypeForFunction(const TypeChecker& tc, FuncCallExpr* call_expr) {
auto id_expr = dynamic_cast<IdExpr*>(call_expr->caller_.get());
if (id_expr != nullptr && tc.GetStaticAnalyzer().GetPackageUnit().GetBuiltinFunctions().Find(id_expr->GetId())) {
if (id_expr != nullptr &&
tc.GetStaticAnalyzer().GetPackageUnit().GetBuiltinFunctions().Find(id_expr->GenerateIdent())) {
call_expr->type = FuncCallExpr::FuncCallType::BUILT_IN_FUNC;
std::vector<std::unique_ptr<Type>> param_types;
for (auto& each : call_expr->parameters_) {
param_types.push_back(tc.GetExpressionType(each.get()));
}
return tc.GetStaticAnalyzer().GetPackageUnit().GetBuiltinFunctions().TypeCheck(id_expr->GetId(), param_types);
return tc.GetStaticAnalyzer().GetPackageUnit().GetBuiltinFunctions().TypeCheck(id_expr->GenerateIdent(),
param_types);
}
// would calcualte id_expr.type_
std::unique_ptr<Type> find_type = tc.GetExpressionType(call_expr->caller_.get());
Expand Down
4 changes: 3 additions & 1 deletion test/builtin_function_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ TEST(builtin_function, context_stack) {
func func3() -> void {
{
let lmd: f(()->void) = lambda [] () -> void { cs = call context_stack:(); return; };
let lmd: f(()->void) = lambda [] () -> void { cs = call wamon::context_stack:(); return; };
call lmd:();
}
return;
Expand All @@ -48,6 +48,8 @@ TEST(builtin_function, context_stack) {
auto tokens = scan.Scan(script);
auto pu = Parse(tokens);
pu = MergePackageUnits(std::move(pu));
EXPECT_THROW(pu.RegisterCppFunctions("mytest", BuiltinFunctions::CheckType(), BuiltinFunctions::HandleType()),
WamonException);

TypeChecker tc(pu);
std::string reason;
Expand Down
10 changes: 5 additions & 5 deletions test/interpreter_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -311,10 +311,10 @@ TEST(interpreter, builtin_function) {
wamon::Scanner scan;
std::string str = R"(
package main;
let v : string = call to_string:(20);
let v2 : string = call to_string:(true);
let v3 : string = call to_string:(3.35);
let v4 : string = call to_string:(0X41);
let v : string = call wamon::to_string:(20);
let v2 : string = call wamon::to_string:(true);
let v3 : string = call wamon::to_string:(3.35);
let v4 : string = call wamon::to_string:(0X41);
)";
wamon::PackageUnit pu;
auto tokens = scan.Scan(str);
Expand Down Expand Up @@ -707,7 +707,6 @@ TEST(interpreter, register_cpp_function) {
wamon::PackageUnit pu;
auto tokens = scan.Scan(str);
pu = wamon::Parse(tokens);
pu = wamon::MergePackageUnits(std::move(pu));

pu.RegisterCppFunctions(
"func111",
Expand All @@ -726,6 +725,7 @@ TEST(interpreter, register_cpp_function) {
return std::make_shared<wamon::IntVariable>(static_cast<int>(len), wamon::Variable::ValueCategory::RValue, "");
});

pu = wamon::MergePackageUnits(std::move(pu));
wamon::TypeChecker tc(pu);
std::string reason;
bool succ = tc.CheckAll(reason);
Expand Down

0 comments on commit 905e0c2

Please sign in to comment.