Skip to content

Commit

Permalink
interactive: add diff to operation count table (#1870)
Browse files Browse the repository at this point in the history
  • Loading branch information
dshaaban01 authored Dec 18, 2023
1 parent 71ea828 commit 705eb8c
Show file tree
Hide file tree
Showing 4 changed files with 135 additions and 34 deletions.
59 changes: 56 additions & 3 deletions tests/interactive/test_pass_metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,10 @@
IntegerAttr,
ModuleOp,
)
from xdsl.interactive.pass_metrics import count_number_of_operations
from xdsl.interactive.pass_metrics import (
count_number_of_operations,
get_diff_operation_count,
)
from xdsl.ir import Block, MLContext, Region
from xdsl.parser import Parser
from xdsl.tools.command_line_tool import get_all_dialects
Expand Down Expand Up @@ -54,12 +57,62 @@ def test_operation_counter_with_parsing_text():
module = parser.parse_module()

expected_res = {
"func.func": 1,
"arith.constant": 1,
"arith.muli": 1,
"func.return": 1,
"builtin.module": 1,
"func.func": 1,
"func.return": 1,
}

res = count_number_of_operations(module)
assert res == expected_res


def test_get_diff_operation_count():
# get input module
input_text = """builtin.module {
func.func @hello(%n : index) -> index {
%two = arith.constant 2 : index
%res = arith.muli %n, %two : index
func.return %res : index
}
}
"""

ctx = MLContext(True)
for dialect in get_all_dialects():
ctx.load_dialect(dialect)
parser = Parser(ctx, input_text)
input_module = parser.parse_module()

# get output module
output_text = """builtin.module {
func.func @hello(%n : index) -> index {
%two = riscv.li 2 : () -> !riscv.reg<>
%two_1 = builtin.unrealized_conversion_cast %two : !riscv.reg<> to index
%res = builtin.unrealized_conversion_cast %n : index to !riscv.reg<>
%res_1 = builtin.unrealized_conversion_cast %two_1 : index to !riscv.reg<>
%res_2 = riscv.mul %res, %res_1 : (!riscv.reg<>, !riscv.reg<>) -> !riscv.reg<>
%res_3 = builtin.unrealized_conversion_cast %res_2 : !riscv.reg<> to index
func.return %res_3 : index
}
}
"""
parser = Parser(ctx.clone(), output_text)
output_module = parser.parse_module()

expected_diff_res: tuple[tuple[str, int, str], ...] = (
("arith.constant", 0, "-1"),
("arith.muli", 0, "-1"),
("builtin.module", 1, "="),
("builtin.unrealized_conversion_cast", 4, "+4"),
("func.func", 1, "="),
("func.return", 1, "="),
("riscv.li", 1, "+1"),
("riscv.mul", 1, "+1"),
)

assert expected_diff_res == get_diff_operation_count(
tuple(count_number_of_operations(input_module).items()),
tuple(count_number_of_operations(output_module).items()),
)
77 changes: 47 additions & 30 deletions xdsl/interactive/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,10 @@
from xdsl.dialects.builtin import ModuleOp
from xdsl.interactive.add_arguments_screen import AddArguments
from xdsl.interactive.load_file_screen import LoadFile
from xdsl.interactive.pass_metrics import count_number_of_operations
from xdsl.interactive.pass_metrics import (
count_number_of_operations,
get_diff_operation_count,
)
from xdsl.ir import MLContext
from xdsl.parser import Parser
from xdsl.passes import ModulePass, PipelinePass, get_pass_argument_names_and_types
Expand Down Expand Up @@ -136,14 +139,23 @@ class InputApp(App[None]):
"""ListView displaying the passes available to apply."""

input_operation_count_tuple = reactive(tuple[tuple[str, int], ...])
"""Saves the operation name and count of the input text area in a dictionary."""
output_operation_count_tuple = reactive(tuple[tuple[str, int], ...])
"""Saves the operation name and count of the output text area in a dictionary."""
"""
Saves the operation name and count of the input text area in a reactive tuple of
tuples.
"""
diff_operation_count_tuple = reactive(tuple[tuple[str, int, str], ...])
"""
Saves the diff of the input_operation_count_tuple and the output_operation_count_tuple
in a reactive tuple of tuples.
"""

input_operation_count_datatable: DataTable[str | int]
"""DataTable displaying the operation names and counts of the input text area."""
output_operation_count_datatable: DataTable[str | int]
"""DataTable displaying the operation names and counts of the output text area."""
diff_operation_count_datatable: DataTable[str | int]
"""
DataTable displaying the diff of operation names and counts of the input and output
text areas.
"""

def __init__(self):
self.input_text_area = TextArea(id="input")
Expand All @@ -153,8 +165,8 @@ def __init__(self):
self.input_operation_count_datatable = DataTable(
id="input_operation_count_datatable"
)
self.output_operation_count_datatable = DataTable(
id="output_operation_count_datatable"
self.diff_operation_count_datatable = DataTable(
id="diff_operation_count_datatable"
)

super().__init__()
Expand Down Expand Up @@ -196,7 +208,7 @@ def compose(self) -> ComposeResult:
yield self.output_text_area
yield Button("Copy Output", id="copy_output_button")
with ScrollableContainer(id="output_ops_container"):
yield self.output_operation_count_datatable
yield self.diff_operation_count_datatable
yield Footer()

def on_mount(self) -> None:
Expand Down Expand Up @@ -224,8 +236,8 @@ def on_mount(self) -> None:
self.input_operation_count_datatable.add_columns("Operation", "Count")
self.input_operation_count_datatable.zebra_stripes = True

self.output_operation_count_datatable.add_columns("Operation", "Count")
self.output_operation_count_datatable.zebra_stripes = True
self.diff_operation_count_datatable.add_columns("Operation", "Count", "Diff")
self.diff_operation_count_datatable.zebra_stripes = True

def compute_available_pass_list(self) -> tuple[type[ModulePass], ...]:
"""
Expand Down Expand Up @@ -376,7 +388,7 @@ def watch_current_module(self):
output_text = output_stream.getvalue()

self.output_text_area.load_text(output_text)
self.update_output_operation_count_tuple()
self.update_operation_count_diff_tuple()

def get_query_string(self) -> str:
"""
Expand All @@ -394,45 +406,50 @@ def update_input_operation_count_tuple(self, input_module: ModuleOp) -> None:
Function that updates the input_operation_datatable to display the operation
names and counts in the input text area.
"""
# sort tuples alphabetically by operation name
self.input_operation_count_tuple = tuple(
count_number_of_operations(input_module).items()
sorted(count_number_of_operations(input_module).items())
)

def watch_input_operation_count_tuple(self) -> None:
"""
Function called when the reactive variable input_operation_count_tuple changes - updates the
Input DataTable.
"""
# clear datatable and add input_operation_count_tuple to DataTable
self.input_operation_count_datatable.clear()
for k, v in self.input_operation_count_tuple:
self.input_operation_count_datatable.add_row(k, v)

self.update_output_operation_count_tuple()
self.input_operation_count_datatable.add_rows(self.input_operation_count_tuple)
self.update_operation_count_diff_tuple()

def update_output_operation_count_tuple(self) -> None:
def update_operation_count_diff_tuple(self) -> None:
"""
Function that updates the output_operation_datatable to display the operation
names and counts in the output text area. It also displays the diff of the input
and output datatable.
Function that updates the diff_operation_count_tuple to calculate the diff
of the input and output operation counts.
"""
match self.current_module:
case None:
self.output_operation_count_tuple = ()
output_operation_count_tuple = ()
case Exception():
self.output_operation_count_tuple = ()
output_operation_count_tuple = ()
case ModuleOp():
self.output_operation_count_tuple = tuple(
count_number_of_operations(self.current_module).items()
# sort tuples alphabetically by operation name
output_operation_count_tuple = tuple(
(k, v)
for (k, v) in sorted(
count_number_of_operations(self.current_module).items()
)
)
self.diff_operation_count_tuple = get_diff_operation_count(
self.input_operation_count_tuple, output_operation_count_tuple
)

def watch_output_operation_count_tuple(self) -> None:
def watch_diff_operation_count_tuple(self) -> None:
"""
Function called when the reactive variable output_operation_count_tuple changes
Function called when the reactive variable diff_operation_count_tuple changes
- updates the Output DataTable.
"""
self.output_operation_count_datatable.clear()
for k, v in self.output_operation_count_tuple:
self.output_operation_count_datatable.add_row(k, v)
self.diff_operation_count_datatable.clear()
self.diff_operation_count_datatable.add_rows(self.diff_operation_count_tuple)

def action_toggle_dark(self) -> None:
"""An action to toggle dark mode."""
Expand Down
2 changes: 1 addition & 1 deletion xdsl/interactive/app.tcss
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@
}

# DataTable
#output_operation_count_datatable{
#diff_operation_count_datatable{
width: auto;
height: auto;
}
Expand Down
31 changes: 31 additions & 0 deletions xdsl/interactive/pass_metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,3 +9,34 @@ def count_number_of_operations(module: ModuleOp) -> dict[str, int]:
occurences of each Operation in the ModuleOp.
"""
return Counter(op.name for op in module.walk())


def get_diff_operation_count(
input_operation_count_tuple: tuple[tuple[str, int], ...],
output_operation_count_tuple: tuple[tuple[str, int], ...],
) -> tuple[tuple[str, int, str], ...]:
"""
Function returning a tuple of tuples containing the diff of the input and output
operation name and count.
"""
input_op_count_dict = dict(input_operation_count_tuple)
output_op_count_dict = dict(output_operation_count_tuple)
all_keys = {*input_op_count_dict, *output_op_count_dict}

res: dict[str, tuple[int, str]] = {}
for k in all_keys:
input_count = input_op_count_dict.get(k, 0)
output_count = output_op_count_dict.get(k, 0)
diff = output_count - input_count

# convert diff to string
if diff == 0:
diff_str = "="
elif diff > 0:
diff_str = f"+{diff}"
else:
diff_str = str(diff)

res[k] = (output_count, diff_str)

return tuple((k, v0, v1) for (k, (v0, v1)) in sorted(res.items()))

0 comments on commit 705eb8c

Please sign in to comment.