Skip to content

Commit

Permalink
Add make_fx_tosa variant to end2end tests (llvm#2240)
Browse files Browse the repository at this point in the history
* Add make_fx_tosa variant to end2end tests

* e2e_testing/xfail_sets.py: Add make_fx_tosa xfail for stable
  • Loading branch information
mgehre-amd authored Jul 13, 2023
1 parent 91c6454 commit f8e75f6
Show file tree
Hide file tree
Showing 5 changed files with 65 additions and 7 deletions.
5 changes: 4 additions & 1 deletion build_tools/python_deploy/build_linux_packages.sh
Original file line number Diff line number Diff line change
Expand Up @@ -300,7 +300,10 @@ function test_in_tree() {
exit 1
;;
esac


echo ":::: Run make_fx + TOSA e2e integration tests"
python -m e2e_testing.main --config=make_fx_tosa -v

echo ":::: Run TorchDynamo e2e integration tests"
python -m e2e_testing.main --config=torchdynamo -v

Expand Down
7 changes: 6 additions & 1 deletion e2e_testing/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@

from .xfail_sets import (
LINALG_XFAIL_SET,
MAKE_FX_TOSA_PASS_SET,
STABLEHLO_PASS_SET,
TOSA_PASS_SET,
LTC_XFAIL_SET,
Expand All @@ -42,7 +43,7 @@
register_all_tests()

def _get_argparse():
config_choices = ["native_torch", "torchscript", "linalg", "stablehlo", "tosa", "lazy_tensor_core", "torchdynamo"]
config_choices = ["native_torch", "torchscript", "linalg", "stablehlo", "make_fx_tosa", "tosa", "lazy_tensor_core", "torchdynamo"]
parser = argparse.ArgumentParser(description="Run torchscript e2e tests.")
parser.add_argument("-c", "--config",
choices=config_choices,
Expand Down Expand Up @@ -94,6 +95,10 @@ def main():
config = TosaBackendTestConfig(LinalgOnTensorsTosaBackend())
xfail_set = all_test_unique_names - TOSA_PASS_SET
crashing_set = set()
elif args.config == "make_fx_tosa":
config = TosaBackendTestConfig(LinalgOnTensorsTosaBackend(), use_make_fx=True)
xfail_set = all_test_unique_names - MAKE_FX_TOSA_PASS_SET
crashing_set = set()
elif args.config == "stablehlo":
config = StablehloBackendTestConfig(LinalgOnTensorsStablehloBackend())
xfail_set = all_test_unique_names - STABLEHLO_PASS_SET
Expand Down
36 changes: 36 additions & 0 deletions e2e_testing/xfail_sets.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
# might be used to keep more elaborate sets of testing configurations).

from torch_mlir_e2e_test.test_suite import COMMON_TORCH_MLIR_LOWERING_XFAILS
from torch_mlir._version import torch_version_for_comparison, version

LINALG_XFAIL_SET = COMMON_TORCH_MLIR_LOWERING_XFAILS

Expand Down Expand Up @@ -1113,6 +1114,41 @@
"ChunkListUnpackUneven_Module_basic",
}

MAKE_FX_TOSA_PASS_SET = (TOSA_PASS_SET | {
### Tests additionally passing in make_fx_tosa
"NativeGroupNormBackwardModule_basic",
"TensorFloatModule_basic",
"TensorIntModule_basic",
}) - {
### Test failing in make_fx_tosa but not in tosa

# failed to lower torch.aten.empty.memory_format
"BatchNorm1DModule_basic",
"BatchNorm1DWith2DInputModule_basic",
"BatchNorm2DModule_basic",
"BatchNorm3DModule_basic",
"BatchNorm1DStaticShapeModule_basic",

# Dynamic shape, has extra unsupported broadcast ops
"Matmul_3d",

# failed to legalize operation 'torch.aten.max_pool2d_with_indices
"MaxPool2dEmptyStrideStaticModule_basic",
"MaxPool2dStaticCeilModeTrueModule_basic",
"MaxPool2dStaticModule_basic",
"ResNet18StaticModule_basic",

# Unimplemented operator 'aten._index_put_impl_.hacked_twin'
"IndexPutImpl1DFloatNonAccumulateModule_basic",
"IndexPutImpl1DIntNonAccumulateModule_basic",
}

if torch_version_for_comparison() < version.parse("2.1.0.dev"):
MAKE_FX_TOSA_PASS_SET -= {
# 'tensor.expand_shape' op expected rank expansion, but found source rank 1 >= result rank 1
"ReshapeCollapseModule_basic",
}

LTC_CRASHING_SET = {
# https://github.com/llvm/torch-mlir/issues/2186
"Add_Module_basic"
Expand Down
19 changes: 16 additions & 3 deletions python/torch_mlir/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@
from torch._functorch.compile_utils import strip_overloads
import torch
import torch.fx
from torch_mlir.dynamo import _get_decomposition_table
from torch.fx.experimental.proxy_tensor import make_fx

from .compiler_utils import run_pipeline_with_repro_report
from torch_mlir.dialects.torch.importer.jit_ir import ClassAnnotator, ImportOptions, ModuleBuilder
Expand Down Expand Up @@ -225,8 +227,11 @@ def _get_for_tracing(
# they know what they are doing and that their trace is
# correct for any specific concrete size.
shape = [s if s != -1 else 7 for s in arg.shape]
example_args_for_trace.append(
torch.ones(*shape, dtype=arg.dtype))
if len(shape) == 0:
example_args_for_trace.append(torch.tensor(1))
else:
example_args_for_trace.append(
torch.ones(*shape, dtype=arg.dtype))
else:
assert isinstance(arg, torch.Tensor)
example_args_for_trace.append(arg)
Expand Down Expand Up @@ -313,7 +318,8 @@ def compile(model: torch.nn.Module,
ignore_traced_shapes=False,
backend_legal_ops: Optional[Sequence[str]] = None,
extra_library: Iterable[Callable] = [],
verbose: bool = False):
verbose: bool = False,
use_make_fx: bool = False):
"""Convert a PyTorch model to MLIR.
Args:
Expand Down Expand Up @@ -367,6 +373,13 @@ def compile(model: torch.nn.Module,
else:
backend_legal_ops = BACKEND_LEGAL_OPS.get(output_type, [])

if use_make_fx:
args = example_args._get_for_tracing(use_tracing=True, ignore_traced_shapes=True)["forward"]
model = make_fx(
model,
decomposition_table=_get_decomposition_table())(*args)


# For FX-based models, automatically strip overloads.
if isinstance(model, torch.fx.GraphModule):
strip_overloads(model)
Expand Down
5 changes: 3 additions & 2 deletions python/torch_mlir_e2e_test/configs/tosa_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,14 +23,15 @@ class TosaBackendTestConfig(TestConfig):
This class handles all the common lowering that torch-mlir does before
reaching the linalg-on-tensors abstraction level.
"""
def __init__(self, backend: TosaBackend):
def __init__(self, backend: TosaBackend, use_make_fx: bool = False):
super().__init__()
self.backend = backend
self.use_make_fx = use_make_fx

def compile(self, program: torch.nn.Module) -> Any:
example_args = convert_annotations_to_placeholders(program.forward)
module = torch_mlir.compile(
program, example_args, output_type="tosa")
program, example_args, output_type="tosa", use_make_fx=self.use_make_fx)

return self.backend.compile(module)

Expand Down

0 comments on commit f8e75f6

Please sign in to comment.