forked from llvm/torch-mlir
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[fx] Implement auto_functionalized higher order op. (llvm#3063)
* Also adds the basic scaffolding for handling more of these, which will be needed for cond, while, etc. * Refactors some of the support in the generic OpOverload emitter so it can be shared with these other special forms. This has been on my list for a while, but it just so happens that as part of upgrading to PyTorch 2.3 and a pure upstream flow in Turbine, we were using a feature that required integration with auto_functionalized. This is perhaps the "weirdest" of the higher-order ops and a poor place to start, but needs must. We have testing for this in Turbine. Full support in Turbine has an entire custom ops facility. I've reduced this down to a unit test in torch-mlir.
- Loading branch information
1 parent
11eaba3
commit e2343cf
Showing
4 changed files
with
246 additions
and
49 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,92 @@ | ||
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. | ||
# See https://llvm.org/LICENSE.txt for license information. | ||
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception | ||
# Also available under a BSD-style license. See LICENSE. | ||
|
||
# RUN: %PYTHON %s | FileCheck %s | ||
|
||
from typing import Optional | ||
|
||
import torch | ||
import torch.export | ||
import torch.nn as nn | ||
|
||
from torch_mlir import fx | ||
|
||
from torch_mlir.ir import ( | ||
Operation, | ||
) | ||
|
||
|
||
# Private operator library used only by this test. "DEF" mode lets us both
# define new op schemas and register kernels against them.
LIBRARY = torch.library.Library("torch_mlir_test", "DEF")

# The `Tensor(a!)` annotation marks the argument as mutated in place, which
# is what routes these ops through the auto_functionalized higher-order op
# during export.
LIBRARY.define("inplace_modify(Tensor(a!) x) -> ()")
LIBRARY.define("inplace_modify_calc(Tensor(a!) x) -> (Tensor)")
|
||
|
||
def inplace_modify_calc_meta(x):
    """Meta kernel for `inplace_modify_calc`: the result mirrors the input's
    shape and dtype (contents are irrelevant for tracing)."""
    result = torch.empty_like(x)
    return result
|
||
|
||
LIBRARY.impl("inplace_modify_calc", inplace_modify_calc_meta, "Meta") | ||
|
||
|
||
def run(f):
    """Decorator that executes *f* at definition time, printing its name
    underlined with dashes before the body runs and a blank line after."""
    title = f.__name__
    print(title)
    print("-" * len(title))
    f()
    print()
|
||
|
||
# CHECK-LABEL: test_auto_functionalized_hop
@run
def test_auto_functionalized_hop():
    """Export a module whose forward calls a void mutating custom op.

    The mutation is routed through the auto_functionalized HOP; the importer
    should emit the op with the mutated input tied to the result (see the
    CHECK lines below).
    """

    class Basic(nn.Module):
        def forward(self, x):
            # Mutating op with no return value; only its side effect on x
            # is observed via the multiply below.
            torch.ops.torch_mlir_test.inplace_modify(x)
            return x * x

    m = fx.export_and_import(
        Basic(),
        torch.randn(3, 4),
        experimental_support_mutation=True,
        # TODO: ExportedProgram.run_decompositions() seems to have trouble
        # with mode selection and Python higher order op implementations.
        # Isolate and report upstream.
        # Raises:
        #   File "/home/stella/v/Dev/lib/python3.11/site-packages/torch/_ops.py", line 323, in dispatch
        #     assert (
        # AssertionError: Current active mode <torch._subclasses.functional_tensor.FunctionalTensorMode object at 0x7a1106504fd0> not registered
        decomposition_table=[],
    )
    # CHECK: %[[TIED:.*]] = torch.operator "torch.torch_mlir_test.inplace_modify"({{.*}}) : (!torch.vtensor<[3,4],f32>) -> !torch.vtensor<[3,4],f32>
    # CHECK: torch.aten.mul.Tensor %[[TIED]], %[[TIED]]
    print(m)
    m.operation.verify()
|
||
|
||
# CHECK-LABEL: test_auto_functionalized_one_ret
@run
def test_auto_functionalized_one_ret():
    """Export a module whose mutating custom op also produces a result.

    Here auto_functionalized yields two values (the CHECK line below shows
    `:2` on the emitted op): the computed result and the functionalized
    version of the mutated input.
    """

    class Basic(nn.Module):
        def forward(self, x):
            # Mutates x in place AND returns a new tensor; both values feed
            # the multiply.
            y = torch.ops.torch_mlir_test.inplace_modify_calc(x)
            return x * y

    m = fx.export_and_import(
        Basic(),
        torch.randn(3, 4),
        experimental_support_mutation=True,
        # TODO: ExportedProgram.run_decompositions() seems to have trouble
        # with mode selection and Python higher order op implementations.
        # Isolate and report upstream.
        # Raises:
        #   File "/home/stella/v/Dev/lib/python3.11/site-packages/torch/_ops.py", line 323, in dispatch
        #     assert (
        # AssertionError: Current active mode <torch._subclasses.functional_tensor.FunctionalTensorMode object at 0x7a1106504fd0> not registered
        decomposition_table=[],
    )
    # CHECK: %[[TIED:.*]]:2 = torch.operator "torch.torch_mlir_test.inplace_modify_calc"(%0) : (!torch.vtensor<[3,4],f32>) -> (!torch.vtensor<[3,4],f32>, !torch.vtensor<[3,4],f32>)
    # CHECK: torch.aten.mul.Tensor %[[TIED]]#1, %[[TIED]]#0
    print(m)
    m.operation.verify()
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters