Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update AST framework and tests #653

Open
wants to merge 3 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions aibolit/ast_framework/_auxiliary_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -401,6 +401,7 @@ class ASTNodeReference(NamedTuple):
},
ASTNodeType.TYPE_PARAMETER: {"extends", "name"},
ASTNodeType.TYPE: {"dimensions", "name"},
ASTNodeType.UNKNOWN: set(), # unknown nodes have no fields
ASTNodeType.VARIABLE_DECLARATION: {
"annotations",
"declarators",
Expand Down
10 changes: 9 additions & 1 deletion aibolit/ast_framework/ast.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.

from collections import namedtuple
from collections import namedtuple, defaultdict
from itertools import islice, repeat, chain

from deprecated import deprecated # type: ignore
Expand Down Expand Up @@ -127,6 +127,12 @@ def traverse(
elif edge_type == "reverse":
on_node_leaving(ASTNode(self.tree, destination))

def create_fake_node(self) -> ASTNode:
fake_nodes_qty = self._fake_nodes_qty_per_graph[self.tree]
self._fake_nodes_qty_per_graph[self.tree] += 1
new_fake_node_id = -(fake_nodes_qty + 1)
return ASTNode(self.tree, new_fake_node_id)

@deprecated(reason='Use ASTNode functionality instead.')
def children_with_type(self, node: int, child_type: ASTNodeType) -> Iterator[int]:
'''
Expand Down Expand Up @@ -379,3 +385,5 @@ def _create_reference_to_node(javalang_node: Node,
return ASTNodeReference(javalang_node_to_index_map[javalang_node])

_UNKNOWN_NODE_TYPE = -1

_fake_nodes_qty_per_graph: Dict[DiGraph, int] = defaultdict(lambda: 0)
61 changes: 42 additions & 19 deletions aibolit/ast_framework/ast_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.

from inspect import getmembers
from typing import Any, List, Iterator, Optional

from networkx import DiGraph, dfs_preorder_nodes # type: ignore
Expand All @@ -41,11 +42,17 @@ def __init__(self, graph: DiGraph, node_index: int):

@property
def children(self) -> Iterator["ASTNode"]:
if self.is_fake:
return iter(())

for child_index in self._graph.succ[self._node_index]:
yield ASTNode(self._graph, child_index)

@property
def parent(self) -> Optional["ASTNode"]:
if self.is_fake:
return None

try:
parent_index = next(self._graph.predecessors(self._node_index))
return ASTNode(self._graph, parent_index)
Expand All @@ -56,8 +63,15 @@ def parent(self) -> Optional["ASTNode"]:
def node_index(self) -> int:
return self._node_index

@property
def is_fake(self) -> bool:
return self._node_index < 0

@cached_property
def line(self) -> int:
if self.is_fake:
return -1

line = self._get_line(self._node_index)
if line is not None:
return line
Expand Down Expand Up @@ -88,12 +102,17 @@ def line(self) -> int:
)

def __getattr__(self, attribute_name: str):
if self.is_fake:
return None

node_type = self._get_type(self._node_index)
javalang_fields = attributes_by_node_type[node_type]
computed_fields = computed_fields_registry.get_fields(node_type)
if attribute_name not in common_attributes and \
attribute_name not in javalang_fields and \
attribute_name not in computed_fields:
if (
attribute_name not in common_attributes
and attribute_name not in javalang_fields
and attribute_name not in computed_fields
):
raise AttributeError(
"Failed to retrieve property. "
f"'{node_type}' node does not have '{attribute_name}' attribute."
Expand All @@ -115,26 +134,28 @@ def __getattr__(self, attribute_name: str):
return attribute

def __dir__(self) -> List[str]:
attribute_names = self._get_public_fixed_interface()
if self.is_fake:
return attribute_names

node_type = self._get_type(self._node_index)
return ASTNode._public_fixed_interface + \
list(common_attributes) + \
list(attributes_by_node_type[node_type]) + \
list(computed_fields_registry.get_fields(node_type).keys())
return (
attribute_names
+ list(common_attributes)
+ list(attributes_by_node_type[node_type])
+ list(computed_fields_registry.get_fields(node_type).keys())
)

def __str__(self) -> str:
text_representation = f"node index: {self._node_index}"
node_type = self._get_type(self._node_index)
for attribute_name in sorted(
common_attributes | attributes_by_node_type[node_type]
):
for attribute_name in sorted(common_attributes | attributes_by_node_type[node_type]):
attribute_value = self.__getattr__(attribute_name)

if isinstance(attribute_value, ASTNode):
attribute_representation = repr(attribute_value)
elif isinstance(attribute_value, str) and "\n" in attribute_value:
attribute_representation = "\n\t" + attribute_value.replace(
"\n", "\n\t"
)
attribute_representation = "\n\t" + attribute_value.replace("\n", "\n\t")
else:
attribute_representation = str(attribute_value)

Expand All @@ -143,7 +164,7 @@ def __str__(self) -> str:
return text_representation

def __repr__(self) -> str:
return f"<ASTNode node_type: {self._get_type(self._node_index)}, node_index: {self._node_index}>"
return f"<ASTNode node_type: {self.node_type}, node_index: {self._node_index}>"

def __eq__(self, other: object) -> bool:
if not isinstance(other, ASTNode):
Expand All @@ -155,9 +176,7 @@ def __eq__(self, other: object) -> bool:
def __hash__(self):
return hash(self._node_index)

def _replace_references_with_nodes(
self, list_with_references: List[Any]
) -> List[Any]:
def _replace_references_with_nodes(self, list_with_references: List[Any]) -> List[Any]:
list_with_nodes: List[Any] = []
for item in list_with_references:
if isinstance(item, ASTNodeReference):
Expand All @@ -179,6 +198,9 @@ def _create_node_from_reference(self, reference: ASTNodeReference) -> "ASTNode":
return ASTNode(self._graph, reference.node_index)

def _get_type(self, node_index: int) -> ASTNodeType:
if self.is_fake:
return ASTNodeType.UNKNOWN

return self._graph.nodes[node_index]["node_type"]

def _get_line(self, node_index: int) -> Optional[int]:
Expand All @@ -188,5 +210,6 @@ def _get_parent(self, node_index: int) -> Optional[int]:
# there is maximum one parent in a tree
return next(self._graph.predecessors(node_index), None)

# names of methods and properties, which is not generated dynamically
_public_fixed_interface = ["children", "node_index", "line"]
@classmethod
def _get_public_fixed_interface(cls) -> List[str]:
return [name for name, _ in getmembers(cls) if not name.startswith("_")]
3 changes: 3 additions & 0 deletions aibolit/ast_framework/block_statement_graph/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
from .block import Block # noqa: F401
from .statement import Statement # noqa: F401
from .builder import build_block_statement_graph # noqa: F401
146 changes: 146 additions & 0 deletions aibolit/ast_framework/block_statement_graph/_block_extractors.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,146 @@
from typing import Callable, Dict, List, Optional, NamedTuple, Union

from ..ast_node import ASTNode
from ..ast_node_type import ASTNodeType
from .constants import BlockReason


class BlockInfo(NamedTuple):
reason: BlockReason
statements: List[ASTNode]
origin_statement: Optional[ASTNode] = None


def extract_blocks_from_statement(statement: ASTNode) -> List[BlockInfo]:
try:
return _block_extractors[statement.node_type](statement)
except KeyError:
raise NotImplementedError(f"Node {statement.node_type} is not supported.")


def _extract_blocks_from_plain_statement(statement: ASTNode) -> List[BlockInfo]:
return []


def _extract_blocks_from_single_block_statement_factory(
field_name: str,
) -> Callable[[ASTNode], List[BlockInfo]]:
def extract_blocks_from_single_block_statement(statement: ASTNode) -> List[BlockInfo]:
return [
BlockInfo(
reason=BlockReason.SINGLE_BLOCK,
statements=_unwrap_block_to_statements_list(getattr(statement, field_name)),
)
]

return extract_blocks_from_single_block_statement


def _extract_blocks_from_if_branching(statement: ASTNode) -> List[BlockInfo]:
block_infos: List[BlockInfo] = []

while statement is not None and statement.node_type == ASTNodeType.IF_STATEMENT:
block_infos.append(
BlockInfo(
reason=BlockReason.THEN_BRANCH,
statements=_unwrap_block_to_statements_list(statement.then_statement),
origin_statement=statement
)
)

statement = statement.else_statement

if statement is not None:
block_infos.append(
BlockInfo(
reason=BlockReason.ELSE_BRANCH,
statements=_unwrap_block_to_statements_list(statement)
)
)

return block_infos


def _extract_blocks_from_switch_branching(statement: ASTNode) -> List[BlockInfo]:
return [BlockInfo(
reason=BlockReason.SINGLE_BLOCK,
statements=[
switch_statement
for switch_case in statement.cases
for switch_statement in switch_case.statements
]
)]


def _extract_blocks_from_try_statement(statement: ASTNode) -> List[BlockInfo]:
block_infos: List[BlockInfo] = []

if statement.resources is not None:
block_infos.append(
BlockInfo(
reason=BlockReason.TRY_RESOURCES,
statements=statement.resources
)
)

block_infos.append(
BlockInfo(
reason=BlockReason.TRY_BLOCK,
statements=_unwrap_block_to_statements_list(statement.block)
)
)

if statement.catches is not None:
for catch_clause in statement.catches:
block_infos.append(
BlockInfo(
reason=BlockReason.CATCH_BLOCK,
statements=_unwrap_block_to_statements_list(catch_clause.block)
)
)

if statement.finally_block is not None:
block_infos.append(
BlockInfo(
reason=BlockReason.FINALLY_BLOCK,
statements=_unwrap_block_to_statements_list(statement.finally_block)
)
)

return block_infos


def _unwrap_block_to_statements_list(
block_statement_or_statement_list: Union[ASTNode, List[ASTNode]]
) -> List[ASTNode]:
if isinstance(block_statement_or_statement_list, ASTNode):
if block_statement_or_statement_list.node_type == ASTNodeType.BLOCK_STATEMENT:
return block_statement_or_statement_list.statements
else:
return [block_statement_or_statement_list]

return block_statement_or_statement_list


_block_extractors: Dict[ASTNodeType, Callable[[ASTNode], List[BlockInfo]]] = {
# plain statements
ASTNodeType.ASSERT_STATEMENT: _extract_blocks_from_plain_statement,
ASTNodeType.BREAK_STATEMENT: _extract_blocks_from_plain_statement,
ASTNodeType.CONTINUE_STATEMENT: _extract_blocks_from_plain_statement,
ASTNodeType.RETURN_STATEMENT: _extract_blocks_from_plain_statement,
ASTNodeType.STATEMENT_EXPRESSION: _extract_blocks_from_plain_statement,
ASTNodeType.THROW_STATEMENT: _extract_blocks_from_plain_statement,
ASTNodeType.LOCAL_VARIABLE_DECLARATION: _extract_blocks_from_plain_statement,
ASTNodeType.TRY_RESOURCE: _extract_blocks_from_plain_statement,
# single block statements
ASTNodeType.BLOCK_STATEMENT: _extract_blocks_from_single_block_statement_factory("statements"),
ASTNodeType.DO_STATEMENT: _extract_blocks_from_single_block_statement_factory("body"),
ASTNodeType.FOR_STATEMENT: _extract_blocks_from_single_block_statement_factory("body"),
ASTNodeType.METHOD_DECLARATION: _extract_blocks_from_single_block_statement_factory("body"),
ASTNodeType.SYNCHRONIZED_STATEMENT: _extract_blocks_from_single_block_statement_factory("block"),
ASTNodeType.WHILE_STATEMENT: _extract_blocks_from_single_block_statement_factory("body"),
ASTNodeType.SWITCH_STATEMENT: _extract_blocks_from_switch_branching,
# multi block statements
ASTNodeType.IF_STATEMENT: _extract_blocks_from_if_branching,
ASTNodeType.TRY_STATEMENT: _extract_blocks_from_try_statement,
}
54 changes: 54 additions & 0 deletions aibolit/ast_framework/block_statement_graph/_nodes_factory.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
from typing import Callable, Union
from networkx import DiGraph, dfs_labeled_edges # type: ignore

from .statement import Statement
from .block import Block
from .constants import NodeType, NodeId, NODE, BLOCK_REASON

TraverseCallback = Callable[[Union[Block, Statement]], None]


class NodesFactory:
@staticmethod
def create_statement_node(graph: DiGraph, id: NodeId) -> Statement:
return Statement(graph, id, NodesFactory.create_block_node, NodesFactory._traverse_graph)

@staticmethod
def create_block_node(graph: DiGraph, id: NodeId) -> Block:
return Block(graph, id, NodesFactory.create_statement_node, NodesFactory._traverse_graph)

@staticmethod
def _detect_and_create_node(graph: DiGraph, id: NodeId) -> Union[Block, Statement]:
node_type = NodesFactory._detect_node_type(graph, id)
if node_type == NodeType.Block:
return NodesFactory.create_block_node(graph, id)
elif node_type == NodeType.Statement:
return NodesFactory.create_statement_node(graph, id)
else:
raise ValueError(f"Unexpected node type {node_type}.")

@staticmethod
def _detect_node_type(graph: DiGraph, id: NodeId) -> NodeType:
node_attributes = graph.nodes(data=True)[id]
if NODE in node_attributes:
return NodeType.Statement
elif BLOCK_REASON in node_attributes:
return NodeType.Block
else:
raise ValueError(f"Cannot identify node with attributes {node_attributes}")

@staticmethod
def _traverse_graph(
graph: DiGraph,
start_node_id: NodeId,
on_node_entering: TraverseCallback,
on_node_leaving: TraverseCallback = lambda _: None,
) -> None:
for _, destination_id, edge_type in dfs_labeled_edges(graph, start_node_id):
destination_node = NodesFactory._detect_and_create_node(graph, destination_id)
if edge_type == "forward":
on_node_entering(destination_node)
elif edge_type == "reverse":
on_node_leaving(destination_node)
else:
raise RuntimeError(f"Unexpected edge type {edge_type}.")
Loading