diff --git a/.github/workflows/main.yaml b/.github/workflows/main.yaml
index 417fd908d..8038b739d 100644
--- a/.github/workflows/main.yaml
+++ b/.github/workflows/main.yaml
@@ -71,7 +71,7 @@ jobs:
       - name: Pull Test Data
         run: git lfs pull
       - name: Run tests
-        run: nox -t ${{ matrix.nox-tag }} --forcecolor -- --cov=onnxscript --cov-report=xml --cov-append --cov-branch -n=auto --junit-xml pytest.xml
+        run: nox -t ${{ matrix.nox-tag }} --forcecolor -- --cov=onnxscript --cov-report=xml --cov-append --cov-branch -n=auto --junitxml junit.xml
         env:
           CATCH_ORT_SEGFAULT: "${{ matrix.os == 'ubuntu-latest' && '1' || '0' }}"
           CREATE_REPRODUCTION_REPORT: "${{ matrix.os == 'ubuntu-latest' && '1' || '0' }}"
@@ -80,12 +80,11 @@ jobs:
         uses: codecov/codecov-action@v4
         with:
           token: ${{ secrets.CODECOV_TOKEN }}
-      - name: Upload Test Results
-        if: always()
-        uses: actions/upload-artifact@v3
+      - name: Upload test results to Codecov
+        if: ${{ !cancelled() }}
+        uses: codecov/test-results-action@v1
         with:
-          name: Test Results (${{ matrix.name }}-${{ matrix.os }})
-          path: pytest.xml
+          token: ${{ secrets.CODECOV_TOKEN }}
       - name: Upload torchlib error reports
         if: always()
         uses: actions/upload-artifact@v3
@@ -161,23 +160,3 @@ jobs:
           echo "Update readme by running `python docs/update_readme.py`"
           exit 1
         fi
-
-  publish-test-results:
-    name: "Publish Tests Results to Github"
-    needs: test
-    runs-on: ubuntu-latest
-    permissions:
-      checks: write
-      # only needed unless run with comment_mode: off
-      pull-requests: write
-    if: always()
-    steps:
-      - name: Download Artifacts
-        uses: actions/download-artifact@v3
-        with:
-          path: artifacts
-
-      - name: Publish Test Results
-        uses: EnricoMi/publish-unit-test-result-action@v2
-        with:
-          files: "artifacts/**/*.xml"
diff --git a/onnxscript/function_libs/torch_lib/ops/core.py b/onnxscript/function_libs/torch_lib/ops/core.py
index e6f446a6d..2ca22c7e4 100644
--- a/onnxscript/function_libs/torch_lib/ops/core.py
+++ b/onnxscript/function_libs/torch_lib/ops/core.py
@@ -3564,7 +3564,7 @@ def aten_flipud(self: TensorType) -> TensorType:
     raise NotImplementedError()
 
 
-@torch_op("aten::floor")
+@torch_op("aten::floor", traceable=True)
 def aten_floor(self: TFloatOrBFloat16) -> TFloatOrBFloat16:
     """floor(Tensor self) -> Tensor"""
 
@@ -3578,13 +3578,22 @@ def python_math_floor(self: TFloatOrBFloat16) -> TInt:
     return op.Cast(floor, to=INT64.dtype)
 
 
-@torch_op(("aten::floor_divide", "_operator::floordiv"))
+@torch_op(("aten::floor_divide", "_operator::floordiv"), traceable=True)
 def aten_floor_divide(self: TFloat, other: TFloat) -> TFloat:
     """floor_divide(Tensor self, Tensor other) -> Tensor"""
 
     return op.Floor(op.Div(self, other))
 
 
+@torch_op(("aten::floor_divide", "_operator::floordiv"), traceable=True)
+def aten_floor_divide_int(self: TInt, other: TInt) -> TInt:
+    """floor_divide(Tensor self, Tensor other) -> Tensor"""
+
+    # We implement floor_divide only for positive inputs (using integer division)
+    # because that is the usual intended case and is the most efficient.
+    return op.Div(self, other)
+
+
 def aten_fmax(self: TensorType, other: TensorType) -> TensorType:
     """fmax(Tensor self, Tensor other) -> Tensor"""
 
@@ -3597,14 +3606,14 @@ def aten_fmin(self: TensorType, other: TensorType) -> TensorType:
     raise NotImplementedError()
 
 
-@torch_op(("aten::fmod.Tensor", "aten::fmod.Scalar"))
+@torch_op(("aten::fmod.Tensor", "aten::fmod.Scalar"), traceable=True)
 def aten_fmod(self: TRealOrUInt8, other: TRealOrUInt8) -> TRealOrUInt8:
     """fmod.Tensor(Tensor self, Tensor other) -> Tensor"""
 
     return op.Mod(self, other, fmod=1)
 
 
-@torch_op("aten::frac")
+@torch_op("aten::frac", traceable=True)
 def aten_frac(self: TFloat) -> TFloat:
     """frac(Tensor self) -> Tensor
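
A note on the new `aten_floor_divide_int`: as far as I know, integer `Div` in ONNX Runtime truncates toward zero, while `floor_divide` is defined to round toward negative infinity, so a single `Div` is only correct when the quotient is non-negative. That is what the comment above and the positive-only test inputs added later in this patch account for. A quick NumPy illustration (mine, not part of the patch):

```python
import numpy as np

a = np.array([7, -7], dtype=np.int64)
b = np.array([2, 2], dtype=np.int64)

# Truncating division, matching integer Div semantics: [ 3 -3]
print(np.sign(a) * np.sign(b) * (np.abs(a) // np.abs(b)))
# Floor division, which aten::floor_divide specifies:  [ 3 -4]
print(a // b)
```
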
diff --git a/onnxscript/ir/_core.py b/onnxscript/ir/_core.py
index 28bdb655b..61dbf5f0b 100644
--- a/onnxscript/ir/_core.py
+++ b/onnxscript/ir/_core.py
@@ -15,6 +15,7 @@
 import abc
 import contextlib
 import dataclasses
+import heapq
 import math
 import mmap
 import os
@@ -670,6 +671,13 @@ def tobytes(self) -> bytes:
         length = self._length or self.nbytes
         return self.raw[offset : offset + length]
 
+    def release(self) -> None:
+        """Delete all references to the memory buffer and close the memory-mapped file."""
+        self._array = None
+        if self.raw is not None:
+            self.raw.close()
+            self.raw = None
+
     @property
     def metadata_props(self) -> dict[str, str]:
         if self._metadata_props is None:
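
A minimal sketch of how the new `release()` is meant to be used, mirroring the test added further down in this patch; the file name, directory, and metadata below are hypothetical:

```python
from onnxscript import ir

tensor = ir.ExternalTensor(
    "tensors.bin",                  # hypothetical data file
    offset=0,
    length=4 * 4,
    dtype=ir.DataType.FLOAT,
    base_dir="/path/to/model_dir",  # hypothetical base directory
    name="weight",
    shape=ir.Shape([4]),
)

data = tensor.tobytes()  # loads (memory-maps) the file on first access
tensor.release()         # closes the mmap and drops the cached array
assert tensor.raw is None
data = tensor.tobytes()  # the file is transparently re-loaded on the next access
```

Releasing eagerly is what lets `_external_data.py` (later in this patch) close file handles early while keeping the tensors usable afterwards.
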
@@ -1977,8 +1985,103 @@ def insert_before(self, node: Node, new_nodes: Iterable[Node] | Node, /) -> None
         self._nodes.insert_before(node, new_nodes)
 
     def sort(self) -> None:
-        """Topologically sort the nodes in the graph."""
-        raise NotImplementedError("Not implemented yet")
+        """Perform a topological sort of this graph and all subgraphs in O(#nodes + #values) time.
+
+        This sort is stable. It preserves the original order as much as possible.
+
+        Reference: https://github.com/madelson/MedallionTopologicalSort#stable-sort
+
+        Raises:
+            ValueError: If the graph contains a cycle, making topological sorting impossible.
+        """
+        # Obtain all nodes from the graph and its subgraphs for sorting
+        nodes = list(onnxscript.ir.traversal.RecursiveGraphIterator(self))
+        # Store the sorted nodes of each subgraph
+        sorted_nodes_by_graph: dict[Graph, list[Node]] = {
+            graph: [] for graph in {node.graph for node in nodes if node.graph is not None}
+        }
+        # TODO: Explain why we need to store direct predecessors and children and why
+        # we only need to store the direct ones
+
+        # The depth of a node is defined as the number of direct children it has
+        node_depth: dict[Node, int] = dict.fromkeys(nodes, 0)
+        # Direct predecessors of a node
+        node_predecessors: dict[Node, list[Node]] = {node: [] for node in nodes}
+        # Store the negative index of the nodes because heapq is a min heap and we
+        # want to pop the node with the largest index value first, effectively turning
+        # it into a max heap
+        neg_node_index: dict[Node, int] = {node: -i for i, node in enumerate(nodes)}
+
+        def add_predecessor(child: Node, predecessor: Node | None) -> None:
+            """Add a predecessor of a node, and increment the depth of the predecessor."""
+            if predecessor is None:
+                return
+            node_predecessors[child].append(predecessor)
+            node_depth[predecessor] += 1
+
+        # 1. Build the direct predecessors of each node and the depth of each node
+        # for sorting topologically using Kahn's algorithm.
+        # Note that when a node contains graph attributes (a.k.a. has subgraphs),
+        # we consider all nodes in the subgraphs *predecessors* of this node. This
+        # way we ensure the implicit dependencies of the subgraphs are captured
+        # as predecessors of the node.
+        for node in nodes:
+            # All producers of input values are considered as direct predecessors.
+            for input_value in node.inputs:
+                if input_value is None:
+                    continue
+                predecessor_node = input_value.producer()
+                add_predecessor(node, predecessor_node)
+            # All nodes in attribute graphs are considered as direct predecessors.
+            for attr in node.attributes.values():
+                if not isinstance(attr, Attr):
+                    continue
+                # A nice thing about this algorithm is that we only need to record
+                # direct predecessors. This continues to be true even with subgraphs.
+                # When a node in a subgraph (a) contains its own subgraphs (b), the
+                # nodes in subgraphs (b) are guaranteed to appear before the node
+                # in (a).
+                if attr.type == _enums.AttributeType.GRAPH:
+                    for predecessor_node in attr.value:
+                        add_predecessor(node, predecessor_node)
+                elif attr.type == _enums.AttributeType.GRAPHS:
+                    for attribute_graph in attr.value:
+                        for predecessor_node in attribute_graph:
+                            add_predecessor(node, predecessor_node)
+
+        # 2. Priority Queue: Track nodes with zero direct children in a priority queue,
+        # using NEGATIVE original index for ordering.
+        # This ensures nodes appearing LATER in the original order are processed EARLIER.
+        # We get REVERSED topological order of each subgraph.
+        priority_queue: list[tuple[int, Node]] = [
+            (neg_node_index[node], node) for node in nodes if node_depth[node] == 0
+        ]
+        heapq.heapify(priority_queue)
+
+        # 3. Topological Sort:
+        num_of_sorted_nodes = 0
+        while priority_queue:
+            # Pop the node with the most negative index and add it to the sorted nodes by subgraph.
+            _, current_node = heapq.heappop(priority_queue)
+            assert current_node.graph is not None
+            sorted_nodes_by_graph[current_node.graph].append(current_node)
+            num_of_sorted_nodes += 1
+            # Decrement the depth of its predecessors. If any predecessor node has zero direct children, push it into the queue.
+            for predecessor_node in node_predecessors[current_node]:
+                node_depth[predecessor_node] -= 1
+                if node_depth[predecessor_node] == 0:
+                    heapq.heappush(
+                        priority_queue, (neg_node_index[predecessor_node], predecessor_node)
+                    )
+
+        # 4. Cycle Check: Ensure all nodes are processed. If not, raise a ValueError indicating a cycle.
+        if num_of_sorted_nodes != len(nodes):
+            raise ValueError("Graph contains a cycle, topological sort is not possible.")
+
+        # 5. Reverse: Reverse the sorted nodes of each subgraph to get the topological order.
+        for graph, sorted_nodes in sorted_nodes_by_graph.items():
+            # The graph container ensures all the nodes are unique so we can safely extend
+            graph.extend(reversed(sorted_nodes))
 
     # End of mutation methods
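
The same reverse-Kahn idea reduced to plain strings instead of IR nodes, as a sketch for intuition (not code from this patch): repeatedly pop the zero-successor node that appeared latest in the original order, collect a reversed topological order, then flip it.

```python
import heapq


def stable_topological_sort(nodes: list[str], edges: set[tuple[str, str]]) -> list[str]:
    """Order nodes so each (predecessor, successor) edge points forward, staying close to the input order."""
    neg_index = {node: -i for i, node in enumerate(nodes)}
    predecessors: dict[str, list[str]] = {node: [] for node in nodes}
    depth = dict.fromkeys(nodes, 0)  # number of direct successors
    for pred, succ in edges:
        predecessors[succ].append(pred)
        depth[pred] += 1
    # Seed with nodes that have no successors; pop latest-in-original-order first.
    heap = [(neg_index[node], node) for node in nodes if depth[node] == 0]
    heapq.heapify(heap)
    reversed_order = []
    while heap:
        _, node = heapq.heappop(heap)
        reversed_order.append(node)
        for pred in predecessors[node]:
            depth[pred] -= 1
            if depth[pred] == 0:
                heapq.heappush(heap, (neg_index[pred], pred))
    if len(reversed_order) != len(nodes):
        raise ValueError("Graph contains a cycle, topological sort is not possible.")
    return list(reversed(reversed_order))


# Already-valid input comes back unchanged:
assert stable_topological_sort(["b", "a", "c"], {("a", "c")}) == ["b", "a", "c"]
# A violated edge moves only what it must:
assert stable_topological_sort(["c", "a", "b"], {("a", "c")}) == ["a", "c", "b"]
```

Popping the latest zero-successor node first is what makes the final reversal respect the original order; a plain FIFO Kahn pass would favor earlier nodes and shuffle more than necessary.
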
@@ -2261,8 +2364,8 @@ def __str__(self) -> str:
     model_version={self.model_version!r},
 >"""
         graph_text = str(self.graph)
-        functions_text = ",\n\n".join(str(func) for func in self.functions.values())
-        return f"{signature}\n{graph_text}" + f"\n\n{functions_text}" * len(self.functions)
+        functions_text = "\n\n".join(str(func) for func in self.functions.values())
+        return f"{signature}\n{graph_text}" + f"\n\n{functions_text}"
 
     def __repr__(self) -> str:
         return f"""\
@@ -2451,7 +2554,7 @@ def insert_before(self, node: Node, new_nodes: Iterable[Node], /) -> None:
         self._graph.insert_before(node, new_nodes)
 
     def sort(self) -> None:
-        """Topologically sort the nodes in the function."""
+        """Perform a topological sort of this function and all subgraphs in O(#nodes + #values) time."""
        self._graph.sort()
 
     # End of mutation methods
diff --git a/onnxscript/ir/_core_test.py b/onnxscript/ir/_core_test.py
index eaff506c5..802bf39de 100644
--- a/onnxscript/ir/_core_test.py
+++ b/onnxscript/ir/_core_test.py
@@ -244,6 +244,26 @@ def test_initialize(self):
         # Ensure repeated reads are consistent
         np.testing.assert_equal(tensor, self.data)
 
+    def test_release_does_not_invalidate_tensor(self):
+        external_tensor = self.model.graph.initializer[0]
+        external_info = onnx.external_data_helper.ExternalDataInfo(external_tensor)
+        tensor = _core.ExternalTensor(
+            external_info.location,
+            offset=external_info.offset,
+            length=external_info.length,
+            dtype=ir.DataType.FLOAT,
+            base_dir=self.base_path,
+            name="input",
+            shape=_core.Shape(external_tensor.dims),
+        )
+        self.assertEqual(tensor.dtype, ir.DataType.FLOAT)
+        self.assertEqual(tensor.tobytes(), self.data.tobytes())
+        # Release tensor
+        tensor.release()
+        self.assertEqual(tensor.raw, None)
+        # Tensor can be re-loaded after release
+        self.assertEqual(tensor.tobytes(), self.data.tobytes())
+
     def test_initialize_with_relative_path(self):
         external_tensor = self.model.graph.initializer[0]
         external_info = onnx.external_data_helper.ExternalDataInfo(external_tensor)
@@ -678,7 +698,6 @@ def test_it_is_added_to_a_graph_if_specified(self):
             (self.v0, self.v1),  # type: ignore
             self.node.outputs,
             nodes=(self.node,),
-            opset_imports={"": 1},
         )
         self.assertIn(self.node, graph)
@@ -798,6 +817,170 @@ def test_remove_safe_removes_uses_of_removed_nodes(self):
 
     # TODO(justinchuby): Test graph mutation methods
 
+    # Test topological sort.
+    # Graph structure:
+    #   nodes: [node, ...]
+    #   edges: [(predecessor_node, successor_node), ...]
+    #   subgraphs: {node: [subgraph, ...]}
+
+    def test_topological_sort_empty_graph(self):
+        graph = _core.Graph(
+            inputs=(),
+            outputs=(),
+            nodes=(),
+        )
+        graph.sort()
+        self.assertEqual(tuple(graph), ())
+
+    def test_topological_sort_linear_dependencies(self):
+        # nodes=[1,2,3], edges=[(1,2),(2,3)]
+        v0 = _core.Value(name="v0")
+        node1 = _core.Node("", "Node1", inputs=(v0,), num_outputs=1)
+        node2 = _core.Node("", "Node2", inputs=(node1.outputs[0],), num_outputs=1)
+        node3 = _core.Node("", "Node3", inputs=(node2.outputs[0],), num_outputs=1)
+        graph = _core.Graph(
+            (v0,),
+            node3.outputs,
+            nodes=(node3, node2, node1),
+        )
+        graph.sort()
+        sorted_nodes = tuple(graph)
+        expected_order = (node1, node2, node3)
+        self.assertEqual(sorted_nodes, expected_order)
+
+    def test_topological_sort_independent_subgraphs(self):
+        # nodes=[1,2,3,4], edges=[(1,3),(2,4)]
+        v0 = _core.Value(name="v0")
+        v1 = _core.Value(name="v1")
+        node1 = _core.Node("", "Node1", inputs=(v0,), num_outputs=1)
+        node2 = _core.Node("", "Node2", inputs=(v1,), num_outputs=1)
+        node3 = _core.Node("", "Node3", inputs=(node1.outputs[0],), num_outputs=1)
+        node4 = _core.Node("", "Node4", inputs=(node2.outputs[0],), num_outputs=1)
+        graph = _core.Graph(
+            (v0, v1),
+            (node3.outputs[0], node4.outputs[0]),
+            nodes=(node4, node3, node2, node1),
+        )
+        graph.sort()
+        sorted_nodes = tuple(graph)
+        expected_order = (node2, node4, node1, node3)
+        self.assertEqual(sorted_nodes, expected_order)
+
+    def test_topological_sort_shared_successor(self):
+        # nodes=[1,2,3], edges=[(1,3),(2,3)]
+        v0 = _core.Value(name="v0")
+        node1 = _core.Node("", "Node1", inputs=(v0,), num_outputs=1)
+        node2 = _core.Node("", "Node2", inputs=(v0,), num_outputs=1)
+        node3 = _core.Node(
+            "", "Node3", inputs=(node1.outputs[0], node2.outputs[0]), num_outputs=1
+        )
+        graph = _core.Graph(
+            (v0,),
+            (node3.outputs[0],),
+            nodes=(node3, node2, node1),
+        )
+        graph.sort()
+        sorted_nodes = tuple(graph)
+        expected_order = (node2, node1, node3)
+        self.assertEqual(sorted_nodes, expected_order)
+
+    def _create_shared_predecessor_nodes(
+        self,
+    ) -> tuple[_core.Value, tuple[_core.Node, _core.Node, _core.Node]]:
+        # nodes=[0,1,2], edges=[(0,1),(0,2)]
+        v0 = _core.Value(name="v0")
+        node0 = _core.Node("", "Node0", inputs=(v0,), num_outputs=1)
+        node1 = _core.Node("", "Node1", inputs=(node0.outputs[0],), num_outputs=1)
+        node2 = _core.Node("", "Node2", inputs=(node0.outputs[0],), num_outputs=1)
+        return v0, (node0, node1, node2)
+
+    @parameterized.parameterized.expand(
+        [
+            ("012", (0, 1, 2), (0, 1, 2)),
+            ("021", (0, 2, 1), (0, 2, 1)),
+            ("102", (1, 0, 2), (0, 1, 2)),
+            ("120", (1, 2, 0), (0, 1, 2)),
+            ("201", (2, 0, 1), (0, 2, 1)),
+            ("210", (2, 1, 0), (0, 2, 1)),
+        ]
+    )
+    def test_topological_sort_shared_predecessor(
+        self, _: str, initial_order: tuple[int], expected_order: tuple[int]
+    ):
+        v0, nodes = self._create_shared_predecessor_nodes()
+        graph = _core.Graph((v0,), (), nodes=[nodes[i] for i in initial_order])
+        graph.sort()
+        sorted_nodes = list(graph)
+        self.assertEqual(sorted_nodes, [nodes[i] for i in expected_order])
+
+    def test_topological_sort_cycle_detection(self):
+        # nodes=[1,2,3], edges=[(1,2),(2,3),(3,2)]
+        v0 = _core.Value(name="v0")
+        node1 = _core.Node("", "Node1", inputs=(v0,), num_outputs=1)
+        node2 = _core.Node("", "Node2", inputs=(node1.outputs[0], v0), num_outputs=1)
+        node3 = _core.Node("", "Node3", inputs=(node2.outputs[0],), num_outputs=1)
+        node2.replace_input_with(1, node3.outputs[0])
+        graph = _core.Graph(
+            (v0,),
+            (node3.outputs[0],),
+            nodes=(node1, node2, node3),
+        )
+        with self.assertRaises(ValueError):
+            graph.sort()
+
+    def test_topological_sort_subgraph(self):
+        # main_graph: nodes=[a,b,c,d,>,if], edges=[(a,>),(b,>),(>,if)], subgraphs={if: [then_graph, else_graph]}
+        # then_graph: nodes=[sub], edges=[(c,sub),(d,sub)]
+        # else_graph: nodes=[add], edges=[(c,add),(d,add)]
+        v0 = _core.Value(name="va")
+        v1 = _core.Value(name="vb")
+        v2 = _core.Value(name="vc")
+        v3 = _core.Value(name="vd")
+        node0 = _core.Node("", "a", inputs=(v0,), num_outputs=1)
+        node1 = _core.Node("", "b", inputs=(v1,), num_outputs=1)
+        node2 = _core.Node("", "c", inputs=(v2,), num_outputs=1)
+        node3 = _core.Node("", "d", inputs=(v3,), num_outputs=1)
+        node4 = _core.Node(
+            "", "sub", inputs=(node2.outputs[0], node3.outputs[0]), num_outputs=1
+        )
+        node5 = _core.Node(
+            "", "add", inputs=(node2.outputs[0], node3.outputs[0]), num_outputs=1
+        )
+        node6 = _core.Node("", ">", inputs=(node0.outputs[0], node1.outputs[0]), num_outputs=1)
+        then_graph = _core.Graph(
+            inputs=(node2.outputs[0], node3.outputs[0]),
+            outputs=(node4.outputs[0],),
+            nodes=(node4,),
+            name="then_graph",
+        )
+        else_graph = _core.Graph(
+            inputs=(node2.outputs[0], node3.outputs[0]),
+            outputs=(node5.outputs[0],),
+            nodes=(node5,),
+            name="else_graph",
+        )
+        node7 = _core.Node(
+            "",
+            "if",
+            inputs=(node6.outputs[0],),
+            num_outputs=1,
+            attributes=[
+                ir.AttrGraph("then_branch", then_graph),
+                ir.AttrGraph("else_branch", else_graph),
+            ],
+        )
+        main_graph_rev = _core.Graph(
+            inputs=(v0, v1, v2, v3),
+            outputs=(node7.outputs[0],),
+            nodes=(node7, node6, node3, node2, node1, node0),  # if, >, d, c, b, a
+            name="main_graph_rev",
+        )
+        main_graph_rev.sort()
+        self.assertEqual(
+            tuple(node.op_type for node in tuple(main_graph_rev)),
+            ("d", "c", "b", "a", ">", "if"),
+        )
+
 
 class TypeTest(unittest.TestCase):
     @parameterized.parameterized.expand(
diff --git a/onnxscript/ir/_external_data.py b/onnxscript/ir/_external_data.py
index 6152491b6..75a7e34bc 100644
--- a/onnxscript/ir/_external_data.py
+++ b/onnxscript/ir/_external_data.py
@@ -100,6 +100,7 @@ def _load_external_data_file(
         if os.path.samefile(tensor.path, os.path.join(base_path, relative_path)):
             # Copy the data as the .numpy() call references data from a file whose data is eventually modified
             tensor_data = external_tensor.numpy().copy()
+            external_tensor.release()
             tensor = _core.Tensor(
                 tensor_data, name=external_tensor.name, dtype=external_tensor.dtype
             )
@@ -165,6 +166,8 @@ def _save_external_data(
         current_offset = tensor_info.offset
         assert tensor is not None
         raw_data = tensor.tobytes()
+        if isinstance(tensor, _core.ExternalTensor):
+            tensor.release()
         # Pad file to required offset if needed
         file_size = data_file.tell()
         if current_offset > file_size:
@@ -223,6 +226,7 @@ def convert_tensors_to_external(
     path = os.path.join(base_path, relative_path)
     # Check if file path is valid, and create subsequent subdirectories within the path if they don't exist
     os.makedirs(os.path.dirname(path), exist_ok=True)
+    tmp_file_created = False
     # Check if file exists. Load pre-existing external data if it does.
     if os.path.exists(path):
         # Check if any tensor in the model is using the destination file
@@ -241,6 +245,7 @@ def convert_tensors_to_external(
         os.makedirs(tmp_path, exist_ok=True)
         # If exisiting external tensors are not loaded to memory, copy the external data to a temporary location
         os.rename(path, os.path.join(tmp_path, relative_path))
+        tmp_file_created = True
     for tensor in tensors:
         if (
             isinstance(tensor, _core.ExternalTensor)
@@ -270,6 +275,12 @@ def convert_tensors_to_external(
         external_tensors[i]
         for i in sorted(range(len(external_tensors)), key=lambda i: sorted_indices[i])
     ]
+
+    # Clean up the temporary file if it was created
+    tmp_path = os.path.join(base_path, "tmp", relative_path)
+    if os.path.exists(tmp_path) and tmp_file_created:
+        os.remove(tmp_path)
+
     return external_tensors
diff --git a/onnxscript/ir/_external_data_test.py b/onnxscript/ir/_external_data_test.py
index 3cf27aa0c..afcf32b20 100644
--- a/onnxscript/ir/_external_data_test.py
+++ b/onnxscript/ir/_external_data_test.py
@@ -1,6 +1,7 @@
 # Copyright (c) Microsoft Corporation.
 # Licensed under the MIT License.
 import os
+import sys
 import tempfile
 import typing
 import unittest
@@ -115,7 +116,10 @@ class OffloadExternalTensorTest(unittest.TestCase):
     def setUp(self):
         # File paths
-        self.temp_dir = tempfile.TemporaryDirectory()  # pylint: disable=consider-using-with
+        if sys.version_info[:2] >= (3, 10):
+            self.temp_dir = tempfile.TemporaryDirectory(ignore_cleanup_errors=True)  # pylint: disable=consider-using-with
+        else:
+            self.temp_dir = tempfile.TemporaryDirectory()  # pylint: disable=consider-using-with
         self.external_data_name = "external_tensors.bin"
         self.base_path = self.temp_dir.name
         self.ext_data_1 = "external_data_1.bin"
@@ -136,7 +140,15 @@ def setUp(self):
         self.model_with_mixed_external_data = self._model_with_mixed_external_data()
 
     def tearDown(self) -> None:
-        self.temp_dir.cleanup()
+        # Handle exceptions for Windows and Python versions < 3.10
+        try:
+            self.temp_dir.cleanup()
+        except PermissionError as e:
+            print(f"PermissionError: {e}")
+        except FileNotFoundError as e:
+            print(f"FileNotFoundError: {e}")
+        except Exception as e:  # pylint: disable=broad-exception-caught
+            print(f"An unexpected error occurred: {e}")
 
     def _simple_model(self) -> ir.Model:
         tensor1 = ir.Tensor(
diff --git a/onnxscript/optimizer/_inliner.py b/onnxscript/optimizer/_inliner.py
index 6a6cd9bb3..a43b8cb0f 100644
--- a/onnxscript/optimizer/_inliner.py
+++ b/onnxscript/optimizer/_inliner.py
@@ -171,7 +171,7 @@ def __init__(self, model: ir.Model) -> None:
 
         self.node_context: dict[ir.Node, CallStack] = {}
 
-    def _instantiate_call(self, node: ir.Node, call_site_id: str) -> NodeReplacement:
+    def _instantiate_call(self, node: ir.Node, call_site_id: CallSiteId) -> NodeReplacement:
         id = node.op_identifier()
         function = self._functions[id]
diff --git a/onnxscript/rewriter/no_op.py b/onnxscript/rewriter/no_op.py
index 7a4b00798..21cee515d 100644
--- a/onnxscript/rewriter/no_op.py
+++ b/onnxscript/rewriter/no_op.py
@@ -23,6 +23,14 @@ def div_by_1(op, x):
     return x / 1
 
 
+def dropout_zero(op, x):
+    return op.Dropout(x, ratio=0.0)
+
+
+def dropout_inference(op, x):
+    return op.Dropout(x, training_mode=False)
+
+
 # Replacement
 def identity(op, x):
     return op.Identity(x)
@@ -32,6 +40,8 @@ def identity(op, x):
 add_0_rule = pattern.RewriteRule(add_0, identity)
 sub_0_rule = pattern.RewriteRule(sub_0, identity)
 div_by_1_rule = pattern.RewriteRule(div_by_1, identity)
+dropout_zero_rule = pattern.RewriteRule(dropout_zero, identity)
+dropout_inference_rule = pattern.RewriteRule(dropout_inference, identity)
 
 # TODO: Include Mul by 0, 0 by Mul, 0 by Div? Those would be 0s, but not no-ops
 rules = pattern.RewriteRuleSet(
@@ -40,5 +50,7 @@ def identity(op, x):
         *add_0_rule.commute(),
         sub_0_rule,
         div_by_1_rule,
+        dropout_zero_rule,
+        dropout_inference_rule,
     ]
 )
diff --git a/onnxscript/rewriter/no_op_test.py b/onnxscript/rewriter/no_op_test.py
index 92172ec1f..4e509e7f3 100644
--- a/onnxscript/rewriter/no_op_test.py
+++ b/onnxscript/rewriter/no_op_test.py
@@ -177,6 +177,26 @@ def test_div_one_should_become_no_op_with_initializer(
         """
         )
 
+    @parameterized.parameterized.expand(
+        [
+            ("dropout zero ratio", "ratio=0.0"),
+            ("dropout inference", "training_mode=0"),
+            ("dropout inference with positive ratio", "ratio=0.42, training_mode=0"),
+            ("dropout training with zero ratio", "ratio=0.0, training_mode=1"),
+        ]
+    )
+    def test_dropout_zero_or_inference_no_op_with_initializer(self, _, attribute: str):
+        self._check(
+            f"""
+            <ir_version: 7, opset_import: [ "" : 17]>
+            agraph (float16[M] input) => (float16[M] output)
+            {{
+                output = Dropout<{attribute}>(input)
+            }}
+            """
+        )
+
+    # TODO: Test the negative cases
 
 
 if __name__ == "__main__":
     unittest.main()
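
To see the new rules end to end, here is a sketch that applies the `no_op` rule set to a model whose `Dropout` carries `ratio=0.0`. The model text mirrors the test above; `ir.serde.deserialize_model` and `RewriteRuleSet.apply_to_model` reflect my reading of the current rewriter API and are assumptions, not part of the patch:

```python
import onnx.parser

from onnxscript import ir
from onnxscript.rewriter import no_op

model_proto = onnx.parser.parse_model(
    """
    <ir_version: 7, opset_import: [ "" : 17]>
    agraph (float16[M] input) => (float16[M] output)
    {
        output = Dropout<ratio=0.0>(input)
    }
    """
)
model = ir.serde.deserialize_model(model_proto)
count = no_op.rules.apply_to_model(model)  # number of rewrites performed
print(count)  # expected: 1 (the Dropout is rewritten to an Identity)
```

A `Dropout` with `ratio=0.0`, or one explicitly marked `training_mode=0`, is a pass-through, which is why both patterns rewrite to the same `identity` replacement.
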
diff --git a/tests/function_libs/torch_lib/extra_opinfo.py b/tests/function_libs/torch_lib/extra_opinfo.py
index 0abced612..756a74027 100644
--- a/tests/function_libs/torch_lib/extra_opinfo.py
+++ b/tests/function_libs/torch_lib/extra_opinfo.py
@@ -1325,7 +1325,6 @@ def sample_inputs__scaled_dot_product_efficient_attention(
 
     dim_4_q_shape = (batch, num_heads, seq_q, head_dim)
     dim_4_kv_shape = (batch, num_heads, seq_kv, head_dim)
-    shape_attn_bias = (batch, num_heads, seq_q, seq_kv)
 
     qkv_shapes = [(dim_4_q_shape, dim_4_kv_shape)]
 
@@ -1339,7 +1338,7 @@ def sample_inputs__scaled_dot_product_efficient_attention(
             make(shape_q),
             make(shape_kv),
             make(shape_kv),
-            attn_bias=make(shape_attn_bias),
+            attn_bias=None,  # TODO: Add attn_bias
             is_causal=is_causal,
             dropout_p=dropout_p,
             compute_log_sumexp=compute_log_sumexp,
@@ -1998,6 +1997,21 @@ def __init__(self):
         sample_inputs_func=sample_inputs__fft_r2c,
         supports_out=False,
     ),
+    opinfo_core.BinaryUfuncInfo(
+        "ops.aten.floor_divide",
+        aten_name="floor_divide",
+        dtypes=common_dtype.floating_types_and_half(),
+        rhs_make_tensor_kwargs=dict(exclude_zero=True),
+    ),
+    opinfo_core.BinaryUfuncInfo(
+        "ops.aten.floor_divide.int",
+        aten_name="floor_divide",
+        op=torch.ops.aten.floor_divide,
+        dtypes=common_dtype.integral_types(),
+        # Create only positive inputs
+        lhs_make_tensor_kwargs=dict(low=0),
+        rhs_make_tensor_kwargs=dict(exclude_zero=True, low=0),
+    ),
     opinfo_core.OpInfo(
         "ops.aten.index.Tensor",
         aten_name="index.Tensor",
diff --git a/tests/function_libs/torch_lib/ops_test_data.py b/tests/function_libs/torch_lib/ops_test_data.py
index 3f9576745..7a475c9ad 100644
--- a/tests/function_libs/torch_lib/ops_test_data.py
+++ b/tests/function_libs/torch_lib/ops_test_data.py
@@ -841,11 +841,12 @@ def _where_input_wrangler(
     ),
     TorchLibOpInfo("flatten", core_ops.aten_flatten),
     TorchLibOpInfo("floor", core_ops.aten_floor),
-    TorchLibOpInfo("floor_divide", core_ops.aten_floor_divide).xfail(
+    TorchLibOpInfo("ops.aten.floor_divide", core_ops.aten_floor_divide).skip(
         dtypes=(torch.float16,),
         test_class_name="TestOutputConsistencyEager",
         reason="fixme: off-by-one issue due to numerical precision. https://github.com/microsoft/onnxscript/issues/989",
     ),
+    TorchLibOpInfo("ops.aten.floor_divide.int", core_ops.aten_floor_divide_int),
     TorchLibOpInfo("fmod", core_ops.aten_fmod),
     TorchLibOpInfo("frac", core_ops.aten_frac),
     TorchLibOpInfo("full", core_ops.aten_full),