diff --git a/test/ast_framework/BlockStatementGraphExamples.java b/test/ast_framework/BlockStatementGraphExamples.java index a1ac006d..c400178c 100644 --- a/test/ast_framework/BlockStatementGraphExamples.java +++ b/test/ast_framework/BlockStatementGraphExamples.java @@ -153,4 +153,8 @@ void complexExample1(int x) { return x; } + + BlockStatementGraphExamples() { // First constructor + System.out.println("Inside constructor."); + } } diff --git a/test/ast_framework/test_block_statement_graph.py b/test/ast_framework/test_block_statement_graph.py index 1f527052..b463430d 100644 --- a/test/ast_framework/test_block_statement_graph.py +++ b/test/ast_framework/test_block_statement_graph.py @@ -1,44 +1,47 @@ -from typing import List, Union +from itertools import islice +from typing import List, Tuple, Union from pathlib import Path from unittest import TestCase from veniq.ast_framework.block_statement_graph import build_block_statement_graph, Block, Statement from veniq.ast_framework.block_statement_graph.constants import BlockReason -from veniq.ast_framework import AST, ASTNodeType +from veniq.ast_framework import AST, ASTNode, ASTNodeType from veniq.utils.ast_builder import build_ast class BlockStatementTestCase(TestCase): def test_single_assert_statement(self): - block_statement_graph = self._get_block_statement_graph("singleAssertStatement") + block_statement_graph = self._get_block_statement_graph_from_method("singleAssertStatement") self.assertEqual( self._flatten_block_statement_graph(block_statement_graph), [ASTNodeType.METHOD_DECLARATION, BlockReason.SINGLE_BLOCK, ASTNodeType.ASSERT_STATEMENT], ) def test_single_return_statement(self): - block_statement_graph = self._get_block_statement_graph("singleReturnStatement") + block_statement_graph = self._get_block_statement_graph_from_method("singleReturnStatement") self.assertEqual( self._flatten_block_statement_graph(block_statement_graph), [ASTNodeType.METHOD_DECLARATION, BlockReason.SINGLE_BLOCK, ASTNodeType.RETURN_STATEMENT], ) def test_single_statement_expression(self): - block_statement_graph = self._get_block_statement_graph("singleStatementExpression") + block_statement_graph = self._get_block_statement_graph_from_method("singleStatementExpression") self.assertEqual( self._flatten_block_statement_graph(block_statement_graph), [ASTNodeType.METHOD_DECLARATION, BlockReason.SINGLE_BLOCK, ASTNodeType.STATEMENT_EXPRESSION], ) def test_single_throw_statement(self): - block_statement_graph = self._get_block_statement_graph("singleThrowStatement") + block_statement_graph = self._get_block_statement_graph_from_method("singleThrowStatement") self.assertEqual( self._flatten_block_statement_graph(block_statement_graph), [ASTNodeType.METHOD_DECLARATION, BlockReason.SINGLE_BLOCK, ASTNodeType.THROW_STATEMENT], ) def test_single_local_variable_declaration(self): - block_statement_graph = self._get_block_statement_graph("singleVariableDeclarationStatement") + block_statement_graph = self._get_block_statement_graph_from_method( + "singleVariableDeclarationStatement" + ) self.assertEqual( self._flatten_block_statement_graph(block_statement_graph), [ @@ -49,7 +52,7 @@ def test_single_local_variable_declaration(self): ) def test_single_block_statement(self): - block_statement_graph = self._get_block_statement_graph("singleBlockStatement") + block_statement_graph = self._get_block_statement_graph_from_method("singleBlockStatement") self.assertEqual( self._flatten_block_statement_graph(block_statement_graph), [ @@ -62,7 +65,7 @@ def test_single_block_statement(self): ) def test_single_do_statement(self): - block_statement_graph = self._get_block_statement_graph("singleDoStatement") + block_statement_graph = self._get_block_statement_graph_from_method("singleDoStatement") self.assertEqual( self._flatten_block_statement_graph(block_statement_graph), [ @@ -75,7 +78,7 @@ def test_single_do_statement(self): ) def test_single_for_statement(self): - block_statement_graph = self._get_block_statement_graph("singleForStatement") + block_statement_graph = self._get_block_statement_graph_from_method("singleForStatement") self.assertEqual( self._flatten_block_statement_graph(block_statement_graph), [ @@ -88,7 +91,7 @@ def test_single_for_statement(self): ) def test_single_synchronize_statement(self): - block_statement_graph = self._get_block_statement_graph("singleSynchronizeStatement") + block_statement_graph = self._get_block_statement_graph_from_method("singleSynchronizeStatement") self.assertEqual( self._flatten_block_statement_graph(block_statement_graph), [ @@ -101,7 +104,7 @@ def test_single_synchronize_statement(self): ) def test_single_while_statement(self): - block_statement_graph = self._get_block_statement_graph("singleWhileStatement") + block_statement_graph = self._get_block_statement_graph_from_method("singleWhileStatement") self.assertEqual( self._flatten_block_statement_graph(block_statement_graph), [ @@ -114,7 +117,7 @@ def test_single_while_statement(self): ) def test_cycle_with_break(self): - block_statement_graph = self._get_block_statement_graph("cycleWithBreak") + block_statement_graph = self._get_block_statement_graph_from_method("cycleWithBreak") self.assertEqual( self._flatten_block_statement_graph(block_statement_graph), [ @@ -127,7 +130,7 @@ def test_cycle_with_break(self): ) def test_cycle_with_continue(self): - block_statement_graph = self._get_block_statement_graph("cycleWithContinue") + block_statement_graph = self._get_block_statement_graph_from_method("cycleWithContinue") self.assertEqual( self._flatten_block_statement_graph(block_statement_graph), [ @@ -140,7 +143,7 @@ def test_cycle_with_continue(self): ) def test_single_if_then_branch(self): - block_statement_graph = self._get_block_statement_graph("singleIfThenBranch") + block_statement_graph = self._get_block_statement_graph_from_method("singleIfThenBranch") self.assertEqual( self._flatten_block_statement_graph(block_statement_graph), [ @@ -153,7 +156,7 @@ def test_single_if_then_branch(self): ) def test_single_if_then_else_branches(self): - block_statement_graph = self._get_block_statement_graph("singleIfThenElseBranches") + block_statement_graph = self._get_block_statement_graph_from_method("singleIfThenElseBranches") self.assertEqual( self._flatten_block_statement_graph(block_statement_graph), [ @@ -168,7 +171,7 @@ def test_single_if_then_else_branches(self): ) def test_several_else_if_branches(self): - block_statement_graph = self._get_block_statement_graph("severalElseIfBranches") + block_statement_graph = self._get_block_statement_graph_from_method("severalElseIfBranches") self.assertEqual( self._flatten_block_statement_graph(block_statement_graph), [ @@ -185,7 +188,7 @@ def test_several_else_if_branches(self): ) def test_if_branches_without_curly_braces(self): - block_statement_graph = self._get_block_statement_graph("ifBranchingWithoutCurlyBraces") + block_statement_graph = self._get_block_statement_graph_from_method("ifBranchingWithoutCurlyBraces") self.assertEqual( self._flatten_block_statement_graph(block_statement_graph), [ @@ -200,7 +203,7 @@ def test_if_branches_without_curly_braces(self): ) def test_switch_branches(self): - block_statement_graph = self._get_block_statement_graph("switchBranches") + block_statement_graph = self._get_block_statement_graph_from_method("switchBranches") self.assertEqual( self._flatten_block_statement_graph(block_statement_graph), [ @@ -219,7 +222,7 @@ def test_switch_branches(self): ) def test_single_try_block(self): - block_statement_graph = self._get_block_statement_graph("singleTryBlock") + block_statement_graph = self._get_block_statement_graph_from_method("singleTryBlock") self.assertEqual( self._flatten_block_statement_graph(block_statement_graph), [ @@ -234,7 +237,7 @@ def test_single_try_block(self): ) def test_full_try_block(self): - block_statement_graph = self._get_block_statement_graph("fullTryBlock") + block_statement_graph = self._get_block_statement_graph_from_method("fullTryBlock") self.assertEqual( self._flatten_block_statement_graph(block_statement_graph), [ @@ -255,7 +258,7 @@ def test_full_try_block(self): ) def test_try_without_catch(self): - block_statement_graph = self._get_block_statement_graph("tryWithoutCatch") + block_statement_graph = self._get_block_statement_graph_from_method("tryWithoutCatch") self.assertEqual( self._flatten_block_statement_graph(block_statement_graph), [ @@ -269,7 +272,7 @@ def test_try_without_catch(self): ) def test_complex_example1(self): - block_statement_graph = self._get_block_statement_graph("complexExample1") + block_statement_graph = self._get_block_statement_graph_from_method("complexExample1") self.assertEqual( self._flatten_block_statement_graph(block_statement_graph), [ @@ -286,7 +289,18 @@ def test_complex_example1(self): ], ) - def _get_block_statement_graph(self, method_name: str) -> Block: + def test_simple_constructor(self): + block_statement_graph = self._get_block_statement_graph_from_constructor(1) + self.assertEqual( + self._flatten_block_statement_graph(block_statement_graph), + [ + ASTNodeType.CONSTRUCTOR_DECLARATION, + BlockReason.SINGLE_BLOCK, + ASTNodeType.STATEMENT_EXPRESSION, + ], + ) + + def _get_class_declaration(self) -> Tuple[str, str, ASTNode, AST]: current_directory = Path(__file__).absolute().parent filename = "BlockStatementGraphExamples.java" ast = AST.build_from_javalang(build_ast(str(current_directory / filename))) @@ -301,6 +315,11 @@ def _get_block_statement_graph(self, method_name: str) -> Block: except StopIteration: raise RuntimeError(f"Can't find class {class_name} in file {filename}") + return filename, class_name, class_declaration, ast + + def _get_block_statement_graph_from_method(self, method_name: str) -> Block: + filename, class_name, class_declaration, ast = self._get_class_declaration() + try: method_declaration = next(node for node in class_declaration.methods if node.name == method_name) except StopIteration: @@ -308,6 +327,18 @@ def _get_block_statement_graph(self, method_name: str) -> Block: return build_block_statement_graph(ast.get_subtree(method_declaration)) + def _get_block_statement_graph_from_constructor(self, constructor_index: int = 1) -> Block: + filename, class_name, class_declaration, ast = self._get_class_declaration() + + try: + method_declaration = next(islice(class_declaration.constructors, constructor_index - 1, None)) + except StopIteration: + raise ValueError( + f"Can't find {constructor_index}th constructor in class {class_name} in file {filename}" + ) + + return build_block_statement_graph(ast.get_subtree(method_declaration)) + @staticmethod def _flatten_block_statement_graph( root: Union[Block, Statement] diff --git a/veniq/ast_framework/block_statement_graph/_block_extractors.py b/veniq/ast_framework/block_statement_graph/_block_extractors.py index cf88bd4c..f0e08999 100644 --- a/veniq/ast_framework/block_statement_graph/_block_extractors.py +++ b/veniq/ast_framework/block_statement_graph/_block_extractors.py @@ -133,6 +133,7 @@ def _unwrap_block_to_statements_list( ASTNodeType.TRY_RESOURCE: _extract_blocks_from_plain_statement, # single block statements ASTNodeType.BLOCK_STATEMENT: _extract_blocks_from_single_block_statement_factory("statements"), + ASTNodeType.CONSTRUCTOR_DECLARATION: _extract_blocks_from_single_block_statement_factory("body"), ASTNodeType.DO_STATEMENT: _extract_blocks_from_single_block_statement_factory("body"), ASTNodeType.FOR_STATEMENT: _extract_blocks_from_single_block_statement_factory("body"), ASTNodeType.METHOD_DECLARATION: _extract_blocks_from_single_block_statement_factory("body"),