diff --git a/parser.y b/parser.y index f4c240a..a32b955 100644 --- a/parser.y +++ b/parser.y @@ -380,7 +380,7 @@ COMMENT { $$ = new CommentNode(std::move($1)); } ; prop_decl: -access_modifier optional_static type IDENTIFIER { $$ = new PropDeclNode(std::move($3), std::move($4), $1, $2); } +access_modifier optional_static var_decl { $$ = new PropDeclNode($3, $1, $2); } ; method_def: diff --git a/src/ast.h b/src/ast.h index f369265..0171dd2 100644 --- a/src/ast.h +++ b/src/ast.h @@ -614,21 +614,19 @@ namespace X { }; class PropDeclNode : public Node { - Type type; - std::string name; + DeclNode *decl; AccessModifier accessModifier; bool isStatic; public: - PropDeclNode(Type type, std::string name, AccessModifier accessModifier = AccessModifier::PUBLIC, bool isStatic = false) : - Node(NodeKind::PropDecl), type(std::move(type)), name(std::move(name)), accessModifier(accessModifier), isStatic(isStatic) {} + PropDeclNode(DeclNode *decl, AccessModifier accessModifier = AccessModifier::PUBLIC, bool isStatic = false) : + Node(NodeKind::PropDecl), decl(decl), accessModifier(accessModifier), isStatic(isStatic) {} void print(Pipes::PrintAst &astPrinter, int level = 0) override; llvm::Value *gen(Codegen::Codegen &codegen) override; Type infer(Pipes::TypeInferrer &typeInferrer) override; - const Type &getType() const { return type; } - const std::string &getName() const { return name; } + DeclNode *getDecl() const { return decl; } AccessModifier getAccessModifier() const { return accessModifier; } bool getIsStatic() const { return isStatic; } diff --git a/src/codegen/class.cpp b/src/codegen/class.cpp index 3bc451c..1b62c40 100644 --- a/src/codegen/class.cpp +++ b/src/codegen/class.cpp @@ -88,6 +88,13 @@ namespace X::Codegen { auto obj = newObj(classDecl.llvmType); + initVtable(obj, classDecl); + + auto initFnName = Mangler::mangleHiddenMethod(Mangler::mangleClass(classDecl.name), INIT_FN_NAME); + if (auto initFn = module.getFunction(initFnName)) { + builder.CreateCall(initFn, {obj}); + } + try { callMethod(obj, classDecl.type, CONSTRUCTOR_FN_NAME, node->getArgs()); } catch (const MethodNotFoundException &e) { @@ -96,8 +103,6 @@ namespace X::Codegen { } } - initVtable(obj, classDecl); - return obj; } @@ -399,9 +404,7 @@ namespace X::Codegen { return builder.CreateCall(callee, llvmArgs); } - std::tuple Codegen::findMethod( - llvm::Value *obj, const Type &objType, const std::string &methodName - ) { + std::tuple Codegen::findMethod(llvm::Value *obj, const Type &objType, const std::string &methodName) { if (objType.isOneOf(Type::TypeID::STRING, Type::TypeID::ARRAY)) { const auto &name = Mangler::mangleInternalMethod(getClassName(objType), methodName); auto fn = module.getFunction(name); diff --git a/src/codegen/codegen.h b/src/codegen/codegen.h index 65610dc..27e2e25 100644 --- a/src/codegen/codegen.h +++ b/src/codegen/codegen.h @@ -70,6 +70,7 @@ namespace X::Codegen { bool isAbstract = false; llvm::StructType *vtableType = nullptr; GC::Metadata *meta; + bool needInit = false; }; struct InterfaceDecl { @@ -152,6 +153,10 @@ namespace X::Codegen { void declFuncs(TopStatementListNode *node); void declGlobals(TopStatementListNode *node); + void genGlobal(DeclNode *node); + void genClassInit(ClassNode *node, const ClassDecl &classDecl); + void genStaticPropInit(PropDeclNode *prop, ClassNode *klass); + llvm::Type *mapType(const Type &type); llvm::Constant *getDefaultValue(const Type &type); /// differs from getDefaultValue because getDefaultValue returns constant and createDefaultValue can emit instructions diff --git a/src/codegen/decl.cpp b/src/codegen/decl.cpp index 34aec45..2745787 100644 --- a/src/codegen/decl.cpp +++ b/src/codegen/decl.cpp @@ -57,6 +57,7 @@ namespace X::Codegen { propPos++; classDecl.parent = const_cast(&parentClassDecl); pointerList = parentClassDecl.meta->pointerList; + classDecl.needInit = parentClassDecl.needInit; } auto it = compilerRuntime.virtualMethods.find(name); @@ -68,8 +69,10 @@ namespace X::Codegen { } for (auto prop: klassNode->getProps()) { - auto &propName = prop->getName(); - auto type = mapType(prop->getType()); + auto &propName = prop->getDecl()->getName(); + auto &type = prop->getDecl()->getType(); + auto llvmType = mapType(type); + if (prop->getIsStatic()) { // check if static prop already declared here, because module.getOrInsertGlobal could return ConstantExpr // (if prop will be redeclared with different type) @@ -77,15 +80,18 @@ namespace X::Codegen { throw PropAlreadyDeclaredException(name, propName); } const auto &mangledPropName = Mangler::mangleStaticProp(mangledName, propName); - auto global = llvm::cast(module.getOrInsertGlobal(mangledPropName, type)); - global->setInitializer(getDefaultValue(prop->getType())); - classDecl.staticProps[propName] = {global, prop->getType(), prop->getAccessModifier()}; + auto global = llvm::cast(module.getOrInsertGlobal(mangledPropName, llvmType)); + classDecl.staticProps[propName] = {global, type, prop->getAccessModifier()}; } else { - props.push_back(type); - auto [_, inserted] = classDecl.props.try_emplace(propName, prop->getType(), propPos++, prop->getAccessModifier()); + props.push_back(llvmType); + auto [_, inserted] = classDecl.props.try_emplace(propName, type, propPos++, prop->getAccessModifier()); if (!inserted) { throw PropAlreadyDeclaredException(name, propName); } + + if (prop->getDecl()->getExpr()) { + classDecl.needInit = true; + } } } @@ -102,6 +108,10 @@ namespace X::Codegen { } classDecl.meta = gc.addMeta(GC::NodeType::CLASS, std::move(pointerList)); + + if (classDecl.needInit) { + genClassInit(klassNode, classDecl); + } } } @@ -149,12 +159,6 @@ namespace X::Codegen { } void Codegen::declGlobals(TopStatementListNode *node) { - auto &globals = node->getGlobals(); - - if (globals.empty()) { - return; - } - // create init function auto initFn = llvm::Function::Create( llvm::FunctionType::get(builder.getVoidTy(), {}, false), @@ -167,30 +171,104 @@ namespace X::Codegen { varScopes.emplace_back(); auto &vars = varScopes.back(); - for (auto decl: globals) { - auto &name = decl->getName(); - if (vars.contains(name)) { - throw VarAlreadyExistsException(name); + // declare global variables + for (auto decl: node->getGlobals()) { + genGlobal(decl); + } + + // set initializers for static props + for (auto klass: node->getClasses()) { + for (auto prop: klass->getProps()) { + if (prop->getIsStatic()) { + genStaticPropInit(prop, klass); + } } + } - auto &type = decl->getType(); - auto llvmType = mapType(type); - auto global = llvm::cast(module.getOrInsertGlobal(name, llvmType)); - - auto value = decl->getExpr() ? - decl->getExpr()->gen(*this) : - createDefaultValue(type); - - if (auto constValue = llvm::dyn_cast(value)) { - global->setInitializer(constValue); - } else { - global->setInitializer(getDefaultValue(type)); - builder.CreateStore(value, global); + builder.CreateRetVoid(); + } + + void Codegen::genGlobal(DeclNode *node) { + auto &vars = varScopes.back(); + + auto &name = node->getName(); + if (vars.contains(name)) { + throw VarAlreadyExistsException(name); + } + + auto &type = node->getType(); + auto llvmType = mapType(type); + auto global = llvm::cast(module.getOrInsertGlobal(name, llvmType)); + + auto value = node->getExpr() ? + castTo(node->getExpr()->gen(*this), node->getExpr()->getType(), type) : + createDefaultValue(type); + + if (auto constValue = llvm::dyn_cast(value)) { + global->setInitializer(constValue); + } else { + global->setInitializer(getDefaultValue(type)); + builder.CreateStore(value, global); + } + + vars[name] = {global, type}; + + gcAddGlobalRoot(global, type); + } + + void Codegen::genStaticPropInit(PropDeclNode *prop, ClassNode *klass) { + auto decl = prop->getDecl(); + auto &type = decl->getType(); + auto llvmType = mapType(decl->getType()); + auto mangledClassName = Mangler::mangleClass(klass->getName()); + auto mangledPropName = Mangler::mangleStaticProp(mangledClassName, decl->getName()); + auto global = llvm::cast(module.getOrInsertGlobal(mangledPropName, llvmType)); + + auto value = decl->getExpr() ? + castTo(decl->getExpr()->gen(*this), decl->getExpr()->getType(), type) : + createDefaultValue(type); + + if (auto constValue = llvm::dyn_cast(value)) { + global->setInitializer(constValue); + } else { + global->setInitializer(getDefaultValue(type)); + builder.CreateStore(value, global); + } + + gcAddGlobalRoot(global, type); + } + + void Codegen::genClassInit(ClassNode *node, const ClassDecl &classDecl) { + auto mangledName = Mangler::mangleClass(classDecl.name); + + // create init function + auto initFn = llvm::Function::Create( + llvm::FunctionType::get(builder.getVoidTy(), {builder.getPtrTy()}, false), + llvm::Function::ExternalLinkage, + Mangler::mangleHiddenMethod(mangledName, INIT_FN_NAME), + module + ); + builder.SetInsertPoint(llvm::BasicBlock::Create(context, "entry", initFn)); + + auto initFnThis = initFn->getArg(0); + + if (classDecl.parent) { + if (auto parentInitFn = module.getFunction(Mangler::mangleHiddenMethod(Mangler::mangleClass(classDecl.parent->name), INIT_FN_NAME))) { + builder.CreateCall(parentInitFn, {initFnThis}); } + } - vars[name] = {global, type}; + for (auto prop: node->getProps()) { + if (prop->getIsStatic() || !prop->getDecl()->getExpr()) { + continue; + } + + auto decl = prop->getDecl(); + auto &type = decl->getType(); + auto value = castTo(decl->getExpr()->gen(*this), decl->getExpr()->getType(), type); + auto ptr = builder.CreateStructGEP(classDecl.llvmType, initFnThis, classDecl.props.at(decl->getName()).pos); - gcAddGlobalRoot(global, type); + builder.CreateStore(value, ptr); } builder.CreateRetVoid(); diff --git a/src/compiler.cpp b/src/compiler.cpp index a34ca58..5c06cf2 100644 --- a/src/compiler.cpp +++ b/src/compiler.cpp @@ -16,7 +16,7 @@ namespace X { (Pipeline{}) .pipe(Pipes::ParseCode(code)) -// .pipe(new Pipes::PrintAst()) +// .pipe(Pipes::PrintAst()) .pipe(Pipes::CheckInterfaces(compilerRuntime)) .pipe(Pipes::CheckAbstractClasses()) .pipe(Pipes::CheckVirtualMethods(compilerRuntime)) diff --git a/src/mangler.h b/src/mangler.h index b845afe..0f48255 100644 --- a/src/mangler.h +++ b/src/mangler.h @@ -16,7 +16,11 @@ namespace X { } static std::string mangleMethod(const std::string &mangledClassName, const std::string &methodName) { - return mangledClassName + "_" + methodName; + return mangledClassName + '_' + methodName; + } + + static std::string mangleHiddenMethod(const std::string &mangledClassName, const std::string &methodName) { + return mangledClassName + '.' + methodName; } static std::string mangleInternalMethod(const std::string &mangledClassName, const std::string &methodName) { @@ -24,7 +28,7 @@ namespace X { } static std::string mangleStaticProp(const std::string &mangledClassName, const std::string &propName) { - return mangledClassName + "_" + propName; + return mangledClassName + '_' + propName; } static std::string mangleInternalFunction(const std::string &fnName) { diff --git a/src/pipes/print_ast.cpp b/src/pipes/print_ast.cpp index 81b18d8..b6e7e46 100644 --- a/src/pipes/print_ast.cpp +++ b/src/pipes/print_ast.cpp @@ -209,7 +209,7 @@ namespace X::Pipes { std::cout << "static "; } - std::cout << node->getType() << ' ' << node->getName() << std::endl; + node->getDecl()->print(*this, level); } void PrintAst::printNode(MethodDefNode *node, int level) { diff --git a/src/pipes/type_inferrer.cpp b/src/pipes/type_inferrer.cpp index 794bfc5..86f0501 100644 --- a/src/pipes/type_inferrer.cpp +++ b/src/pipes/type_inferrer.cpp @@ -60,10 +60,14 @@ namespace X::Pipes { classProps[name] = classProps[klass->getParent()]; } + auto &props = classProps[name]; + for (auto prop: klass->getProps()) { - auto &type = prop->getType(); - checkLvalueTypeIsValid(type); - classProps[name][prop->getName()] = {type, prop->getIsStatic()}; + auto decl = prop->getDecl(); + + checkDecl(decl); + + props[decl->getName()] = {decl->getType(), prop->getIsStatic()}; } } } @@ -261,35 +265,9 @@ namespace X::Pipes { } Type TypeInferrer::infer(DeclNode *node) { - auto &type = node->getType(); - - if (type.is(Type::TypeID::AUTO)) { - // infer expr type and change decl type accordingly - if (!node->getExpr()) { - throw InvalidTypeException(); - } - - auto exprType = node->getExpr()->infer(*this); - - checkLvalueTypeIsValid(exprType); - - node->setType(exprType); - varScopes.back()[node->getName()] = exprType; - - return Type::voidTy(); - } - - checkLvalueTypeIsValid(type); - - if (node->getExpr()) { - auto exprType = node->getExpr()->infer(*this); - - if (!canCastTo(exprType, type)) { - throw InvalidTypeException(); - } - } + checkDecl(node); - varScopes.back()[node->getName()] = type; + varScopes.back()[node->getName()] = node->getType(); return Type::voidTy(); } @@ -674,6 +652,35 @@ namespace X::Pipes { checkTypeIsValid(type); } + void TypeInferrer::checkDecl(DeclNode *node) { + auto &type = node->getType(); + + if (type.is(Type::TypeID::AUTO)) { + // infer expr type and change decl type accordingly + if (!node->getExpr()) { + throw InvalidTypeException(); + } + + auto exprType = node->getExpr()->infer(*this); + + checkLvalueTypeIsValid(exprType); + + node->setType(exprType); + + return; + } + + checkLvalueTypeIsValid(type); + + if (node->getExpr()) { + auto exprType = node->getExpr()->infer(*this); + + if (!canCastTo(exprType, type)) { + throw InvalidTypeException(); + } + } + } + void TypeInferrer::checkFnCall(const FnType &fnType, const ExprList &args) { if (fnType.args.size() != args.size()) { throw TypeInferrerException("call args mismatch"); diff --git a/src/pipes/type_inferrer.h b/src/pipes/type_inferrer.h index 5edae0d..8be3530 100644 --- a/src/pipes/type_inferrer.h +++ b/src/pipes/type_inferrer.h @@ -79,6 +79,7 @@ namespace X::Pipes { void checkTypeIsValid(const Type &type) const; void checkLvalueTypeIsValid(const Type &type) const; void checkArgTypeIsValid(const Type &type) const; + void checkDecl(DeclNode *node); void checkFnCall(const FnType &fnType, const ExprList &args); const Type &getMethodReturnType(FnDeclNode *fnDecl, const std::string &className) const; const Type &getMethodReturnType(FnDefNode *fnDef, const std::string &className) const; diff --git a/tests/class_test.cpp b/tests/class_test.cpp index 6a2dff8..7a624b7 100644 --- a/tests/class_test.cpp +++ b/tests/class_test.cpp @@ -583,3 +583,53 @@ class Bar { } )code", "a"); } + +TEST_F(ClassTest, staticWithInitializer) { + checkProgram(R"code( +class Foo { + public static int a = 1 +} + +class Bar { + public static auto b = Foo::a + 2 +} + +fn main() void { + println(Bar::b) +} +)code", "3"); +} + +TEST_F(ClassTest, propWithInitializer) { + checkProgram(R"code( +class Foo { + public int a = 1 +} + +class Bar extends Foo { + public int b = 2 +} + +fn main() void { + auto bar = new Bar() + + println(bar.a + bar.b) +} +)code", "3"); +} + +TEST_F(ClassTest, parentWithInit) { + checkProgram(R"code( +class Foo { + public int a = 1 +} + +class Bar extends Foo {} + +fn main() void { + auto bar = new Bar() + + println(bar.a) +} +)code", "1"); +} diff --git a/tests/statement_test.cpp b/tests/statement_test.cpp index 8fbe01d..2a06f92 100644 --- a/tests/statement_test.cpp +++ b/tests/statement_test.cpp @@ -57,3 +57,15 @@ fn main() void { )code", R"output(1 2)output"); } + +TEST_F(StatementTest, globals2) { + checkProgram(R"code( +int foo = 1 +float bar = 2 +auto fooBar = foo + bar + +fn main() void { + println(fooBar) +} +)code", "3"); +}