Skip to content

Commit

Permalink
refactor struct trait's implement
Browse files Browse the repository at this point in the history
  • Loading branch information
chloro-pn committed Jun 25, 2024
1 parent 693ba6b commit 70b8fd2
Show file tree
Hide file tree
Showing 6 changed files with 79 additions and 106 deletions.
10 changes: 2 additions & 8 deletions include/wamon/interpreter.h
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
#pragma once

#include <functional>
#include <iostream>
#include <string>
#include <unordered_map>
#include <unordered_set>
Expand Down Expand Up @@ -228,20 +229,13 @@ class Interpreter {
return ret->IsRValue() ? std::move(ret) : ret->Clone();
}

// 注意:CallMethodByName会尝试进行struct trait -> struct的转换,而CallMethod不会,因此用户不要直接使用CallMethod接口
std::shared_ptr<Variable> CallMethodByName(std::shared_ptr<Variable> obj, const std::string& method_name,
std::vector<std::shared_ptr<Variable>>&& params) {
if (!IsStructOrEnumType(obj->GetType())) {
return CallMethod(obj, method_name, std::move(params));
}
auto type = obj->GetTypeInfo();
auto struct_def = pu_.FindStruct(type);
if (struct_def == nullptr) {
throw WamonException("CallMethodByName error, struct {} not exist", type);
}
if (struct_def->IsTrait()) {
return CallMethodByName(AsStructVariable(obj)->GetTraitObj(), method_name, std::move(params));
}
const StructDef* struct_def = AsStructVariable(obj)->GetStructDef();
auto method_def = struct_def->GetMethod(method_name);
if (method_def == nullptr) {
throw WamonException("CallMethodByName error, method {} not exist", method_name);
Expand Down
25 changes: 14 additions & 11 deletions include/wamon/variable.h
Original file line number Diff line number Diff line change
Expand Up @@ -44,9 +44,9 @@ class Variable {

virtual ~Variable() = default;

std::string GetTypeInfo() const { return type_->GetTypeInfo(); }
virtual std::string GetTypeInfo() const { return type_->GetTypeInfo(); }

std::unique_ptr<Type> GetType() const { return type_->Clone(); }
virtual std::unique_ptr<Type> GetType() const { return type_->Clone(); }

void SetName(const std::string& name) { name_ = name; }

Expand Down Expand Up @@ -397,9 +397,15 @@ class StructVariable : public Variable {
public:
StructVariable(const StructDef* sd, ValueCategory vc, Interpreter& i, const std::string& name);

std::string GetTypeInfo() const override;

std::unique_ptr<Type> GetType() const override;

const StructDef* GetStructDef() const { return def_; }

std::shared_ptr<Variable> GetDataMemberByName(const std::string& name);

std::shared_ptr<Variable> GetTraitObj() const { return trait_proxy_; }
void AddDataMemberByName(const std::string& name, std::shared_ptr<Variable> data);

void UpdateDataMemberByName(const std::string& name, std::shared_ptr<Variable> data);

Expand All @@ -411,17 +417,15 @@ class StructVariable : public Variable {

~StructVariable();

void set_trait_def(const StructDef* def) { trait_def_ = def; }

private:
static void trait_construct(StructVariable* lv, StructVariable* rv);

static bool trait_compare(StructVariable* lv, StructVariable* rv);

static void trait_assign(StructVariable* lv, StructVariable* rv);

void check_trait_not_null(const char* file, int line) {
if (trait_proxy_ == nullptr) {
throw WamonException("check trait not null feiled, {} {}", file, line);
}
}

// 目前仅支持相同类型的trait间的比较和赋值
bool Compare(const std::shared_ptr<Variable>& other) override;

Expand All @@ -433,14 +437,13 @@ class StructVariable : public Variable {

private:
const StructDef* def_;
const StructDef* trait_def_;
Interpreter& ip_;
struct member {
std::string name;
std::shared_ptr<Variable> data;
};
std::vector<member> data_members_;
// only valid when def_ is struct trait.
std::shared_ptr<Variable> trait_proxy_;
};

inline StructVariable* AsStructVariable(const std::shared_ptr<Variable>& v) {
Expand Down
6 changes: 1 addition & 5 deletions src/ast.cc
Original file line number Diff line number Diff line change
Expand Up @@ -57,11 +57,7 @@ std::shared_ptr<Variable> MethodCallExpr::Calculate(Interpreter& interpreter) {
auto result = interpreter.CallMethod(v, method_name_, std::move(params));
return result;
}
auto structdef = interpreter.GetPackageUnit().FindStruct(v->GetTypeInfo());
while (structdef->IsTrait() == true) {
v = AsStructVariable(v)->GetTraitObj();
structdef = interpreter.GetPackageUnit().FindStruct(v->GetTypeInfo());
}
auto structdef = AsStructVariable(v)->GetStructDef();
auto methoddef = structdef->GetMethod(method_name_);
auto result = interpreter.CallMethod(v, methoddef, std::move(params));
return result;
Expand Down
7 changes: 4 additions & 3 deletions src/operator.cc
Original file line number Diff line number Diff line change
Expand Up @@ -239,9 +239,10 @@ static void register_buildin_operator_handles(std::unordered_map<std::string, Op
return std::make_shared<StringVariable>(enum_str, Variable::ValueCategory::RValue, "");
}
// struct to struct trait
auto v = VariableFactory(to_type, Variable::ValueCategory::RValue, "", interpreter);
v->ConstructByFields({v1});
return v;
const StructDef* trait_def = interpreter.GetPackageUnit().FindStruct(to_type->GetTypeInfo());
assert(trait_def->IsTrait() == true);
AsStructVariable(v1)->set_trait_def(trait_def);
return v1;
};
}

Expand Down
132 changes: 55 additions & 77 deletions src/variable.cc
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
#include "wamon/variable.h"

#include <cassert>

#include "wamon/enum_def.h"
#include "wamon/interpreter.h"
#include "wamon/method_def.h"
Expand Down Expand Up @@ -65,13 +67,15 @@ std::shared_ptr<Variable> VariableFactory(const std::unique_ptr<Type>& type, Var
std::shared_ptr<Variable> GetVoidVariable() { return std::make_shared<VoidVariable>(); }

StructVariable::StructVariable(const StructDef* sd, ValueCategory vc, Interpreter& ip, const std::string& name)
: Variable(std::make_unique<BasicType>(sd->GetStructName()), vc, name), def_(sd), ip_(ip) {}
: Variable(std::make_unique<BasicType>(sd->GetStructName()), vc, name), def_(sd), trait_def_(def_), ip_(ip) {}

std::string StructVariable::GetTypeInfo() const { return trait_def_->GetStructName(); }

std::unique_ptr<Type> StructVariable::GetType() const {
return std::make_unique<BasicType>(trait_def_->GetStructName());
}

std::shared_ptr<Variable> StructVariable::GetDataMemberByName(const std::string& name) {
if (def_->IsTrait()) {
check_trait_not_null(__FILE__, __LINE__);
return AsStructVariable(trait_proxy_)->GetDataMemberByName(name);
}
for (auto& each : data_members_) {
if (each.name == name) {
return each.data;
Expand All @@ -80,12 +84,11 @@ std::shared_ptr<Variable> StructVariable::GetDataMemberByName(const std::string&
return nullptr;
}

void StructVariable::AddDataMemberByName(const std::string& name, std::shared_ptr<Variable> data) {
data_members_.push_back({name, data});
}

void StructVariable::UpdateDataMemberByName(const std::string& name, std::shared_ptr<Variable> data) {
if (def_->IsTrait()) {
check_trait_not_null(__FILE__, __LINE__);
AsStructVariable(trait_proxy_)->UpdateDataMemberByName(name, data);
return;
}
data->ChangeTo(vc_);
auto it =
std::find_if(data_members_.begin(), data_members_.end(), [&](auto& each) -> bool { return each.name == name; });
Expand All @@ -96,22 +99,19 @@ void StructVariable::UpdateDataMemberByName(const std::string& name, std::shared
}

void StructVariable::ConstructByFields(const std::vector<std::shared_ptr<Variable>>& fields) {
assert(def_ == trait_def_);
if (def_->IsTrait()) {
if (fields[0] == nullptr) {
trait_proxy_ = nullptr;
} else if (fields[0]->IsRValue()) {
trait_proxy_ = fields[0];
} else {
trait_proxy_ = fields[0]->Clone();
}
if (trait_proxy_ != nullptr) {
trait_proxy_->ChangeTo(vc_);
}
assert(fields.size() == 1);
StructVariable* other_struct = static_cast<StructVariable*>(fields[0].get());
trait_construct(this, other_struct);
def_ = other_struct->def_;
assert(def_->IsTrait() == false);
return;
}
auto& members = def_->GetDataMembers();
if (fields.size() != members.size()) {
throw WamonException("StructVariable's ConstructByFields method error : fields.size() == {}", fields.size());
throw WamonException("StructVariable's ConstructByFields method error : fields.size() == {}, type == {}",
fields.size(), def_->GetStructName());
}
for (size_t i = 0; i < members.size(); ++i) {
if (!IsSameType(fields[i]->GetType(), members[i].second)) {
Expand All @@ -128,10 +128,7 @@ void StructVariable::ConstructByFields(const std::vector<std::shared_ptr<Variabl
}

void StructVariable::DefaultConstruct() {
if (def_->IsTrait()) {
trait_proxy_ = nullptr;
return;
}
trait_def_ = def_;
data_members_.clear();
auto& members = def_->GetDataMembers();
for (auto& each : members) {
Expand All @@ -142,30 +139,18 @@ void StructVariable::DefaultConstruct() {
}

std::shared_ptr<Variable> StructVariable::Clone() {
if (def_->IsTrait()) {
std::shared_ptr<Variable> proxy{nullptr};
if (trait_proxy_ == nullptr) {
; // do nothing
} else if (IsRValue()) {
proxy = trait_proxy_;
} else {
proxy = trait_proxy_->Clone();
}
auto ret = std::make_shared<StructVariable>(def_, ValueCategory::RValue, ip_, "");
ret->ConstructByFields({proxy});
return ret;
}
std::vector<std::shared_ptr<Variable>> variables;
std::vector<StructVariable::member> variables;
for (auto& each : data_members_) {
if (each.data->IsRValue()) {
variables.push_back(std::move(each.data));
variables.push_back({each.name, std::move(each.data)});
} else {
variables.push_back(each.data->Clone());
variables.push_back({each.name, each.data->Clone()});
}
}
// all variable in variables is rvalue now
auto ret = std::make_shared<StructVariable>(def_, ValueCategory::RValue, ip_, "");
ret->ConstructByFields(variables);
ret->data_members_ = std::move(variables);
ret->set_trait_def(trait_def_);
return ret;
}

Expand All @@ -179,16 +164,26 @@ StructVariable::~StructVariable() {
}
}

bool StructVariable::trait_compare(StructVariable* lv, StructVariable* rv) {
if (lv->trait_proxy_ == nullptr && rv->trait_proxy_ == nullptr) {
return true;
}
if (lv->trait_proxy_ == nullptr || rv->trait_proxy_ == nullptr) {
return false;
void StructVariable::trait_construct(StructVariable* lv, StructVariable* rv) {
for (const auto& each : lv->trait_def_->GetDataMembers()) {
auto tmp = rv->GetDataMemberByName(each.first);
if (rv->IsRValue()) {
assert(tmp->IsRValue());
tmp->ChangeTo(lv->vc_);
lv->AddDataMemberByName(each.first, tmp);
} else {
assert(tmp->IsRValue() == false);
auto ttmp = tmp->Clone();
ttmp->ChangeTo(lv->vc_);
lv->AddDataMemberByName(each.first, ttmp);
}
}
}

bool StructVariable::trait_compare(StructVariable* lv, StructVariable* rv) {
// 类型分析阶段保证
assert(lv->def_ == rv->def_);
for (auto& each : lv->def_->GetDataMembers()) {
assert(lv->trait_def_ == rv->trait_def_);
for (const auto& each : lv->trait_def_->GetDataMembers()) {
auto tmp = lv->GetDataMemberByName(each.first);
auto tmp2 = rv->GetDataMemberByName(each.first);
if (tmp->Compare(tmp2) == false) {
Expand All @@ -200,19 +195,8 @@ bool StructVariable::trait_compare(StructVariable* lv, StructVariable* rv) {

void StructVariable::trait_assign(StructVariable* lv, StructVariable* rv) {
// 类型分析阶段保证
assert(lv->def_ == rv->def_);
if (lv->trait_proxy_ == nullptr && rv->trait_proxy_ == nullptr) {
return;
}
if (rv->trait_proxy_ == nullptr) {
lv->trait_proxy_ = nullptr;
return;
}
if (lv->trait_proxy_ == nullptr) {
lv->trait_proxy_ = rv->IsRValue() ? rv->trait_proxy_ : rv->trait_proxy_->Clone();
lv->trait_proxy_->ChangeTo(lv->vc_);
}
for (auto& each : lv->def_->GetDataMembers()) {
assert(lv->trait_def_ == rv->trait_def_);
for (auto& each : lv->trait_def_->GetDataMembers()) {
auto tmp = rv->GetDataMemberByName(each.first);
if (rv->IsRValue()) {
assert(tmp->IsRValue());
Expand All @@ -229,11 +213,11 @@ void StructVariable::trait_assign(StructVariable* lv, StructVariable* rv) {

// 目前仅支持相同类型的trait间的比较和赋值
bool StructVariable::Compare(const std::shared_ptr<Variable>& other) {
check_compare_type_match(other);
if (def_->IsTrait()) {
return trait_compare(this, AsStructVariable(other));
}
StructVariable* other_struct = static_cast<StructVariable*>(other.get());
assert(trait_def_ == other_struct->trait_def_);
if (trait_def_->IsTrait()) {
return trait_compare(this, other_struct);
}
for (size_t index = 0; index < data_members_.size(); ++index) {
if (data_members_[index].data->Compare(other_struct->data_members_[index].data) == false) {
return false;
Expand All @@ -243,11 +227,11 @@ bool StructVariable::Compare(const std::shared_ptr<Variable>& other) {
}

void StructVariable::Assign(const std::shared_ptr<Variable>& other) {
check_compare_type_match(other);
if (def_->IsTrait()) {
return trait_assign(this, AsStructVariable(other));
}
StructVariable* other_struct = static_cast<StructVariable*>(other.get());
assert(trait_def_ == other_struct->trait_def_);
if (trait_def_->IsTrait()) {
return trait_assign(this, other_struct);
}
if (other_struct->IsRValue()) {
for (size_t index = 0; index < data_members_.size(); ++index) {
data_members_[index].data = std::move(other_struct->data_members_[index].data);
Expand All @@ -263,12 +247,6 @@ void StructVariable::Assign(const std::shared_ptr<Variable>& other) {

void StructVariable::ChangeTo(ValueCategory vc) {
vc_ = vc;
if (def_->IsTrait()) {
if (trait_proxy_ != nullptr) {
trait_proxy_->ChangeTo(vc);
}
return;
}
for (auto& each : data_members_) {
assert(each.data != nullptr);
each.data->ChangeTo(vc);
Expand Down
5 changes: 3 additions & 2 deletions test/interpreter_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -455,8 +455,8 @@ TEST(interpreter, trait) {
let v0 : have_age_and_name = move t2;
let v1 : have_age = (move t1);
let v2 : have_age = (move v0);
let v1 : have_age = move t1;
let v2 : have_age = move v0;
func test() -> bool {
return v1 == v2;
Expand All @@ -479,6 +479,7 @@ TEST(interpreter, trait) {
wamon::Interpreter interpreter(pu);
auto v = interpreter.FindVariableById("main$v1");
EXPECT_EQ(wamon::AsStructVariable(v)->GetDataMemberByName("a")->GetTypeInfo(), "int");
EXPECT_EQ(wamon::AsStructVariable(v)->GetStructDef()->GetStructName(), "main$s1");
auto ret = interpreter.CallFunctionByName("main$test", {});
EXPECT_EQ(ret->GetTypeInfo(), "bool");
EXPECT_EQ(wamon::AsBoolVariable(ret)->GetValue(), true);
Expand Down

0 comments on commit 70b8fd2

Please sign in to comment.