Commit
Merge branch 'rama/inliner' of https://github.com/microsoft/onnxscript into rama/inliner
gramalingam committed Sep 5, 2024
2 parents 31f69ed + 6bdfc6a commit 794bb2e
Showing 11 changed files with 386 additions and 42 deletions.
31 changes: 5 additions & 26 deletions .github/workflows/main.yaml
@@ -71,7 +71,7 @@ jobs:
- name: Pull Test Data
run: git lfs pull
- name: Run tests
run: nox -t ${{ matrix.nox-tag }} --forcecolor -- --cov=onnxscript --cov-report=xml --cov-append --cov-branch -n=auto --junit-xml pytest.xml
run: nox -t ${{ matrix.nox-tag }} --forcecolor -- --cov=onnxscript --cov-report=xml --cov-append --cov-branch -n=auto --junitxml junit.xml
env:
CATCH_ORT_SEGFAULT: "${{ matrix.os == 'ubuntu-latest' && '1' || '0' }}"
CREATE_REPRODUCTION_REPORT: "${{ matrix.os == 'ubuntu-latest' && '1' || '0' }}"
@@ -80,12 +80,11 @@
uses: codecov/codecov-action@v4
with:
token: ${{ secrets.CODECOV_TOKEN }}
- name: Upload Test Results
if: always()
uses: actions/upload-artifact@v3
- name: Upload test results to Codecov
if: ${{ !cancelled() }}
uses: codecov/test-results-action@v1
with:
name: Test Results (${{ matrix.name }}-${{ matrix.os }})
path: pytest.xml
token: ${{ secrets.CODECOV_TOKEN }}
- name: Upload torchlib error reports
if: always()
uses: actions/upload-artifact@v3
@@ -161,23 +160,3 @@ jobs:
echo "Update readme by running `python docs/update_readme.py`"
exit 1
fi
publish-test-results:
name: "Publish Tests Results to Github"
needs: test
runs-on: ubuntu-latest
permissions:
checks: write
# only needed unless run with comment_mode: off
pull-requests: write
if: always()
steps:
- name: Download Artifacts
uses: actions/download-artifact@v3
with:
path: artifacts

- name: Publish Test Results
uses: EnricoMi/publish-unit-test-result-action@v2
with:
files: "artifacts/**/*.xml"
17 changes: 13 additions & 4 deletions onnxscript/function_libs/torch_lib/ops/core.py
@@ -3564,7 +3564,7 @@ def aten_flipud(self: TensorType) -> TensorType:
raise NotImplementedError()


@torch_op("aten::floor")
@torch_op("aten::floor", traceable=True)
def aten_floor(self: TFloatOrBFloat16) -> TFloatOrBFloat16:
"""floor(Tensor self) -> Tensor"""

@@ -3578,13 +3578,22 @@ def python_math_floor(self: TFloatOrBFloat16) -> TInt:
return op.Cast(floor, to=INT64.dtype)


@torch_op(("aten::floor_divide", "_operator::floordiv"))
@torch_op(("aten::floor_divide", "_operator::floordiv"), traceable=True)
def aten_floor_divide(self: TFloat, other: TFloat) -> TFloat:
"""floor_divide(Tensor self, Tensor other) -> Tensor"""

return op.Floor(op.Div(self, other))


@torch_op(("aten::floor_divide", "_operator::floordiv"), traceable=True)
def aten_floor_divide_int(self: TInt, other: TInt) -> TInt:
"""floor_divide(Tensor self, Tensor other) -> Tensor"""

# We implement floor_divide only for positive inputs (using integer division)
# because that is the usual intended case and is the most efficient.
return op.Div(self, other)
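
The shortcut above matches true floor division only in the positive-input case the comment describes: assuming ONNX integer Div rounds toward zero (C-style truncation), the results diverge for negative operands, where PyTorch's floor_divide rounds toward negative infinity. A small numpy sketch, purely illustrative:

import numpy as np

a, b = np.int64(-7), np.int64(2)
floor_result = np.floor_divide(a, b)      # -4: floor division, the PyTorch floor_divide semantics
trunc_result = np.int64(np.trunc(a / b))  # -3: rounds toward zero, like C-style integer division
assert floor_result != trunc_result       # the shortcut is only safe for the positive-input case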


def aten_fmax(self: TensorType, other: TensorType) -> TensorType:
"""fmax(Tensor self, Tensor other) -> Tensor"""

@@ -3597,14 +3606,14 @@ def aten_fmin(self: TensorType, other: TensorType) -> TensorType:
raise NotImplementedError()


@torch_op(("aten::fmod.Tensor", "aten::fmod.Scalar"))
@torch_op(("aten::fmod.Tensor", "aten::fmod.Scalar"), traceable=True)
def aten_fmod(self: TRealOrUInt8, other: TRealOrUInt8) -> TRealOrUInt8:
"""fmod.Tensor(Tensor self, Tensor other) -> Tensor"""

return op.Mod(self, other, fmod=1)
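
For context, Mod with fmod=1 gives C fmod semantics (the remainder takes the sign of the dividend), matching torch.fmod, whereas fmod=0 would follow the divisor's sign like Python's %. A quick numpy illustration (not onnxscript code):

import numpy as np

print(np.fmod(-7, 3))  # -1: sign of the dividend, the fmod=1 / torch.fmod behavior
print(np.mod(-7, 3))   #  2: sign of the divisor, what fmod=0 would compute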


@torch_op("aten::frac")
@torch_op("aten::frac", traceable=True)
def aten_frac(self: TFloat) -> TFloat:
"""frac(Tensor self) -> Tensor
113 changes: 108 additions & 5 deletions onnxscript/ir/_core.py
Expand Up @@ -15,6 +15,7 @@
import abc
import contextlib
import dataclasses
import heapq
import math
import mmap
import os
@@ -670,6 +671,13 @@ def tobytes(self) -> bytes:
length = self._length or self.nbytes
return self.raw[offset : offset + length]

def release(self) -> None:
"""Delete all references to the memory buffer and close the memory-mapped file."""
self._array = None
if self.raw is not None:
self.raw.close()
self.raw = None
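
A minimal usage sketch of the new method (hypothetical variable names; construction of the external tensor is elided):

# `tensor` is an ir.ExternalTensor backed by a memory-mapped weights file.
data = tensor.tobytes()  # forces the lazy memory map and copies out the requested byte range
process(data)            # hypothetical downstream consumer
tensor.release()         # drop the cached array and close the memory map, freeing the file handle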

@property
def metadata_props(self) -> dict[str, str]:
if self._metadata_props is None:
@@ -1977,8 +1985,103 @@ def insert_before(self, node: Node, new_nodes: Iterable[Node] | Node, /) -> None
self._nodes.insert_before(node, new_nodes)

def sort(self) -> None:
"""Topologically sort the nodes in the graph."""
raise NotImplementedError("Not implemented yet")
"""Perform a topological sort of this graph and all subgraphs in O(#nodes + #values) time.
This sort is stable. It preserves the original order as much as possible.
Referece: https://github.com/madelson/MedallionTopologicalSort#stable-sort
Raises:
ValueError: If the graph contains a cycle, making topological sorting impossible.
"""
# Obtain all nodes from the graph and its subgraphs for sorting
nodes = list(onnxscript.ir.traversal.RecursiveGraphIterator(self))
# Store the sorted nodes of each subgraph
sorted_nodes_by_graph: dict[Graph, list[Node]] = {
graph: [] for graph in {node.graph for node in nodes if node.graph is not None}
}
# TODO: Explain why we need to store direct predecessors and children and why
# we only need to store the direct ones

# The depth of a node is defined as the number of direct children it has
node_depth: dict[Node, int] = dict.fromkeys(nodes, 0)
# Direct predecessors of a node
node_predecessors: dict[Node, list[Node]] = {node: [] for node in nodes}
# Store the negative index of the nodes because heapq is a min heap and we
# want to pop the node with largest index value first, effectively turning
# it to a max heap
neg_node_index: dict[Node, int] = {node: -i for i, node in enumerate(nodes)}

def add_predecessor(child: Node, predecessor: Node | None) -> None:
"""Add a predecessor of a node, and increment the depth of the predecessor."""
if predecessor is None:
return
node_predecessors[child].append(predecessor)
node_depth[predecessor] += 1

# 1. Build the direct predecessors of each node and the depth of each node
# for sorting topologically using Kahn's algorithm.
# Note that when a node contains graph attributes (aka. has subgraphs),
# we consider all nodes in the subgraphs *predecessors* of this node. This
# way we ensure the implicit dependencies of the subgraphs are captured
# as predecessors of the node.
for node in nodes:
# All producers of input values are considered as direct predecessors.
for input_value in node.inputs:
if input_value is None:
continue
predecessor_node = input_value.producer()
add_predecessor(node, predecessor_node)
# All nodes in attribute graphs are considered as direct predecessors.
for attr in node.attributes.values():
if not isinstance(attr, Attr):
continue
# A nice thing about this algorithm is that we only need to record
# direct predecessors. This continues to be true even with subgraphs.
# When a node in a subgraph (a) contains its own subgraphs (b), the
# nodes in subgraphs (b) are guaranteed to appear before the node
# in (a).
if attr.type == _enums.AttributeType.GRAPH:
for predecessor_node in attr.value:
add_predecessor(node, predecessor_node)
elif attr.type == _enums.AttributeType.GRAPHS:
for attribute_graph in attr.value:
for predecessor_node in attribute_graph:
add_predecessor(node, predecessor_node)

# 2. Priority Queue: Track nodes with zero direct children in a priority queue,
# using NEGATIVE original index for ordering.
# This ensures nodes appearing LATER in the original order are processed EARLIER.
# We get REVERSED topological order of each subgraph.
priority_queue: list[tuple[int, Node]] = [
(neg_node_index[node], node) for node in nodes if node_depth[node] == 0
]
heapq.heapify(priority_queue)

# 3. Topological Sort:
num_of_sorted_nodes = 0
while priority_queue:
# Pop the node with the most negative index and add it to the sorted nodes by subgraph.
_, current_node = heapq.heappop(priority_queue)
assert current_node.graph is not None
sorted_nodes_by_graph[current_node.graph].append(current_node)
num_of_sorted_nodes += 1
# Decrement the depth of its predecessors. If any predecessor node has zero direct children, push it into the queue.
for predecessor_node in node_predecessors[current_node]:
node_depth[predecessor_node] -= 1
if node_depth[predecessor_node] == 0:
heapq.heappush(
priority_queue, (neg_node_index[predecessor_node], predecessor_node)
)

# 4. Cycle Check: Ensure all nodes are processed. If not, raise a ValueError indicating a cycle.
if num_of_sorted_nodes != len(nodes):
raise ValueError("Graph contains a cycle, topological sort is not possible.")

# 5. Reverse: Reverse the sorted nodes of each subgraph to get the topological order.
for graph, sorted_nodes in sorted_nodes_by_graph.items():
# The graph container ensures all the nodes are unique so we can safely extend
graph.extend(reversed(sorted_nodes))

# End of mutation methods
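
Outside the diff, the same stable reverse-Kahn idea can be condensed into a standalone sketch over a plain dependency function (illustrative only, not onnxscript API):

import heapq

def stable_toposort(nodes, predecessors_of):
    """Stable topological sort: later-listed nodes are popped first, then the result is reversed."""
    index = {node: i for i, node in enumerate(nodes)}
    depth = dict.fromkeys(nodes, 0)  # number of direct children of each node
    for node in nodes:
        for pred in predecessors_of(node):
            depth[pred] += 1
    # Max-heap on the original index via negated keys: pop the latest zero-child node first.
    heap = [(-index[node], node) for node in nodes if depth[node] == 0]
    heapq.heapify(heap)
    reverse_order = []
    while heap:
        _, node = heapq.heappop(heap)
        reverse_order.append(node)
        for pred in predecessors_of(node):
            depth[pred] -= 1
            if depth[pred] == 0:
                heapq.heappush(heap, (-index[pred], pred))
    if len(reverse_order) != len(nodes):
        raise ValueError("Graph contains a cycle, topological sort is not possible.")
    # The loop yields a reverse topological order; flip it so predecessors come first.
    return list(reversed(reverse_order))

# Example: "c" depends on "a" and "b"; unrelated nodes keep their original relative order.
deps = {"c": ["a", "b"]}
print(stable_toposort(["a", "b", "c"], lambda n: deps.get(n, [])))  # ['a', 'b', 'c']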

@@ -2261,8 +2364,8 @@ def __str__(self) -> str:
model_version={self.model_version!r},
>"""
graph_text = str(self.graph)
functions_text = ",\n\n".join(str(func) for func in self.functions.values())
return f"{signature}\n{graph_text}" + f"\n\n{functions_text}" * len(self.functions)
functions_text = "\n\n".join(str(func) for func in self.functions.values())
return f"{signature}\n{graph_text}" + f"\n\n{functions_text}"

def __repr__(self) -> str:
return f"""\
@@ -2451,7 +2554,7 @@ def insert_before(self, node: Node, new_nodes: Iterable[Node], /) -> None:
self._graph.insert_before(node, new_nodes)

def sort(self) -> None:
"""Topologically sort the nodes in the function."""
"""Perform a topological sort of this graph and all subgraphs in O(#nodes + #values) time."""
self._graph.sort()

# End of mutation methods