diff --git a/jaclang/compiler/absyntree.py b/jaclang/compiler/absyntree.py index 21a9a6382..3243397e4 100644 --- a/jaclang/compiler/absyntree.py +++ b/jaclang/compiler/absyntree.py @@ -271,12 +271,16 @@ def __init__( body: Sequence[ElementStmt | String | EmptyToken], is_imported: bool, kid: Sequence[AstNode], + impl_mod: Optional[Module] = None, + test_mod: Optional[Module] = None, ) -> None: """Initialize whole program node.""" self.name = name self.source = source self.body = body self.is_imported = is_imported + self.impl_mod = impl_mod + self.test_mod = test_mod self.mod_deps: dict[str, Module] = {} AstNode.__init__(self, kid=kid) AstDocNode.__init__(self, doc=doc) diff --git a/jaclang/compiler/passes/main/import_pass.py b/jaclang/compiler/passes/main/import_pass.py index b753ebdec..8621e6e37 100644 --- a/jaclang/compiler/passes/main/import_pass.py +++ b/jaclang/compiler/passes/main/import_pass.py @@ -27,6 +27,7 @@ def before_pass(self) -> None: def enter_module(self, node: ast.Module) -> None: """Run Importer.""" self.cur_node = node + self.annex_impl(node) self.terminate() # Turns off auto traversal for deliberate traversal self.run_again = True while self.run_again: @@ -35,25 +36,42 @@ def enter_module(self, node: ast.Module) -> None: for i in all_imports: if i.lang.tag.value == "jac" and not i.sub_module: self.run_again = True - mod = ( - self.import_module( - node=i, - mod_path=node.loc.mod_path, - ) - if i.lang.tag.value == "jac" - else self.import_py_module(node=i, mod_path=node.loc.mod_path) + mod = self.import_module( + node=i, + mod_path=node.loc.mod_path, ) if not mod: self.run_again = False continue + self.annex_impl(mod) i.sub_module = mod i.add_kids_right([mod], pos_update=False) # elif i.lang.tag.value == "py": # self.import_py_module(node=i, mod_path=node.loc.mod_path) self.enter_import(i) SubNodeTabPass(prior=self, input_ir=node) + self.annex_impl(node) node.mod_deps = self.import_table + def annex_impl(self, node: ast.Module) -> None: + """Annex impl and test modules.""" + if not node.loc.mod_path: + self.error("Module has no path") + if node.loc.mod_path.endswith(".jac") and path.exists( + f"{node.loc.mod_path[:-4]}.impl.jac" + ): + mod = self.import_mod_from_file(f"{node.loc.mod_path[:-4]}.impl.jac") + if mod: + node.impl_mod = mod + node.add_kids_right([mod], pos_update=False) + if node.loc.mod_path.endswith(".jac") and path.exists( + f"{node.loc.mod_path[:-4]}.test.jac" + ): + mod = self.import_mod_from_file(f"{node.loc.mod_path[:-4]}.test.jac") + if mod: + node.test_mod = mod + node.add_kids_right([mod], pos_update=False) + def enter_import(self, node: ast.Import) -> None: """Sub objects. @@ -73,19 +91,22 @@ def enter_import(self, node: ast.Import) -> None: def import_module(self, node: ast.Import, mod_path: str) -> ast.Module | None: """Import a module.""" - from jaclang.compiler.transpiler import jac_file_to_pass - from jaclang.compiler.passes.main import SubNodeTabPass - self.cur_node = node # impacts error reporting target = import_target_to_relative_path( node.path.path_str, path.dirname(node.loc.mod_path) ) + return self.import_mod_from_file(target) - if target in self.import_table: - return self.import_table[target] + def import_mod_from_file(self, target: str) -> ast.Module | None: + """Import a module from a file.""" + from jaclang.compiler.transpiler import jac_file_to_pass + from jaclang.compiler.passes.main import SubNodeTabPass if not path.exists(target): - self.error(f"Could not find module {target}", node_override=node) + self.error(f"Could not find module {target}") + return None + if target in self.import_table: + return self.import_table[target] try: mod_pass = jac_file_to_pass(file_path=target, target=SubNodeTabPass) self.errors_had += mod_pass.errors_had @@ -99,9 +120,7 @@ def import_module(self, node: ast.Import, mod_path: str) -> ast.Module | None: mod.is_imported = True return mod else: - self.error( - f"Module {target} is not a valid Jac module.", node_override=node - ) + self.error(f"Module {target} is not a valid Jac module.") return None def import_py_module(self, node: ast.Import, mod_path: str) -> Optional[ast.Module]: