
Commit

Fixes/updates to get tests passing
Signed-off-by: Ian <[email protected]>
IanNod committed Sep 28, 2024
1 parent 38f40b6 commit 0392be8
Showing 10 changed files with 48 additions and 46 deletions.
2 changes: 1 addition & 1 deletion shark_turbine/aot/builtins/jittable.py
@@ -308,7 +308,7 @@ def flat_wrapped_f(*args):

     def _split_py_arg(self, arg) -> Tuple[Value, Any]:
         if isinstance(arg, IrTensor):
-            meta_tensor, _ = arg._to_meta_tensor()
+            meta_tensor = arg._to_meta_tensor()
             return arg.ir_value, meta_tensor

         raise TypeError(f"Unsupported argument to jittable: {arg}")
3 changes: 2 additions & 1 deletion shark_turbine/aot/passes/functorch.py
@@ -11,6 +11,7 @@
     GraphModule,
 )
 from torch.fx.experimental import proxy_tensor
+from torch._subclasses.fake_tensor import unset_fake_temporarily
 from torch.utils import _pytree as pytree


@@ -43,7 +44,7 @@
 def functorch_functionalize(gm_callable: Any, *args) -> GraphModule:
     functionalized_callable = _functionalize_callabale(gm_callable)
     # TODO: There is more of a dance needed if the user has entered with a fake_mode.
-    with proxy_tensor.maybe_disable_fake_tensor_mode():
+    with unset_fake_temporarily():
         new_gm = proxy_tensor.make_fx(
             functionalized_callable,
             decomposition_table={},
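
Note: the commit swaps proxy_tensor.maybe_disable_fake_tensor_mode for unset_fake_temporarily from torch._subclasses.fake_tensor, which plays the same role of suspending an active FakeTensorMode. A minimal standalone sketch of that behavior (not part of this commit):

import torch
from torch._subclasses.fake_tensor import FakeTensorMode, unset_fake_temporarily

with FakeTensorMode():
    fake = torch.empty(2, 2)        # a FakeTensor while the mode is active
    with unset_fake_temporarily():
        real = torch.empty(2, 2)    # fake mode suspended here: a real tensor
print(type(fake).__name__, type(real).__name__)
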
20 changes: 5 additions & 15 deletions shark_turbine/aot/support/procedural/iree_emitter.py
@@ -199,7 +199,7 @@ def tensor_reshape(
         result_value = flow_d.TensorReshapeOp(
             result_type,
             source.ir_value,
-            source.get_only_dynamic_dim_values(constant_cache=constant_cache),
+            [],  # forcing empty list for dynamic dims until supported in CompiledModule
             result_dynamic_dims,
         ).result
         result = IrImmediateTensor(result_value, dtype=source.dtype)
@@ -276,7 +276,7 @@ def tensor_slice(
         result_value = flow_d.TensorSliceOp(
             result_type,
             source_value,
-            source.get_only_dynamic_dim_values(constant_cache=constant_cache),
+            [],  # forcing empty list for dynamic dims until supported in CompiledModule
             start_index_values,
             length_values,
             result_dynamic_dims,
@@ -295,26 +295,19 @@ def tensor_update(
         """Applies an update to a target at start_indices and returns the mutated target."""
         constant_cache: Dict[int, Value] = {}
         target = cast_tensor_value(target)
-        target_dynamic_dims = target.get_only_dynamic_dim_values(
-            constant_cache=constant_cache
-        )
         update = cast_tensor_value(update)
-        update_dynamic_dims = update.get_only_dynamic_dim_values(
-            constant_cache=constant_cache
-        )
         start_index_dim_values = [
             cast_index_value(idx, constant_cache=constant_cache)
             for idx in start_indices
         ]
         result_value = flow_d.TensorUpdateOp(
             target.ir_value,
-            target_dynamic_dims,
+            [],  # forcing empty list for dynamic dims until supported in CompiledModule
             start_index_dim_values,
             update.ir_value,
-            update_dynamic_dims,
+            [],  # forcing empty list for updated dynamic dims until supported in CompiledModule
         ).result
         result = IrImmediateTensor(result_value, target.dtype)
-        result.set_dynamic_dim_values(target_dynamic_dims)
         return result

     @emitter
@@ -342,11 +335,8 @@ def tensor_splat(

     @emitter
     def tensor_trace(self, key: str, *ts: BuildableTensorType):
-        dynamic_dims = []
-        for t in ts:
-            dynamic_dims.extend(t.get_only_dynamic_dim_values())
         ts = tuple(cast_tensor_value(t).ir_value for t in ts)
-        flow_d.TensorTraceOp(StringAttr.get(key), ts, dynamic_dims)
+        flow_d.TensorTraceOp(StringAttr.get(key), ts, [])


 # Circular imports to resolve typing.
4 changes: 2 additions & 2 deletions shark_turbine/dynamo/passes.py
@@ -2,7 +2,7 @@
 from torch.fx.experimental.proxy_tensor import make_fx
 from torch._decomp import get_decompositions
 from torch.func import functionalize
-from typing import List, Optional
+from typing import List, Optional, Mapping

 from .decompositions import DEFAULT_DECOMPOSITIONS

@@ -15,7 +15,7 @@ def apply_decompositions(
     if decompose_ops is None:
         return gm

-    decompositions = get_decompositions(decompose_ops)
+    decompositions: Mapping = get_decompositions(decompose_ops)
     gm = make_fx(
         functionalize(gm),
         decomposition_table=decompositions,
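
Note: get_decompositions returns a dict-like mapping from operator overloads to their decomposition functions, which is what the added Mapping annotation captures. A quick illustrative sketch (the op chosen here is arbitrary, not from this repo):

from typing import Mapping
import torch
from torch._decomp import get_decompositions

# Request the registered decomposition for one aten op; the result is dict-like.
decompositions: Mapping = get_decompositions([torch.ops.aten.native_layer_norm.default])
for op, fn in decompositions.items():
    print(op, fn)
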
2 changes: 1 addition & 1 deletion shark_turbine/runtime/device.py
@@ -503,7 +503,7 @@ def _create_hip_device(torch_device: torch.device, props) -> Optional[Device]:
     if device:
         gcn_arch_name = gcn_arch_name
         device.compile_target_flags = device.compile_target_flags + (
-            f"--iree-rocm-target-chip={gcn_arch_name}",
+            f"--iree-hip-target={gcn_arch_name}",
         )
         device._recompute_target_keys()
     return device
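
Note: the old --iree-rocm-target-chip flag appears to have been renamed upstream in IREE; the replacement --iree-hip-target is appended to the device's compile flags in the same way. A simplified stand-in sketch (the architecture string is just an example, not taken from this repo):

# Hypothetical stand-in for Device.compile_target_flags handling.
gcn_arch_name = "gfx90a"  # example value; the real one comes from the HIP device properties
compile_target_flags: tuple = ()
compile_target_flags = compile_target_flags + (f"--iree-hip-target={gcn_arch_name}",)
assert compile_target_flags == ("--iree-hip-target=gfx90a",)
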
4 changes: 4 additions & 0 deletions tests/aot/functionalize_test.py
@@ -6,6 +6,7 @@

 import logging
 import unittest
+import pytest

 import torch

@@ -34,6 +35,9 @@ def compute():
         print(module_str)
         self.assertNotIn("add_", module_str)

+    @pytest.mark.xfail(
+        reason="CompiledModule dynamic dims no longer supported in latest torch versions"
+    )
     def testDynamicDims(self):
         class ProcArgsModule(CompiledModule):
             def dynamic_dim(self, a=AbstractTensor(None, 2), b=AbstractTensor(None, 1)):
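
Note: the xfail marker added above keeps the suite green while recording the known regression with dynamic dims. A minimal sketch of the same pattern with a hypothetical test (not one from this repo):

import pytest

@pytest.mark.xfail(
    reason="CompiledModule dynamic dims no longer supported in latest torch versions"
)
def test_dynamic_dims_placeholder():
    # Stand-in body: the real tests build a CompiledModule with
    # AbstractTensor(None, ...) arguments, which currently fails on newer torch.
    raise NotImplementedError
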
50 changes: 28 additions & 22 deletions tests/aot/iree_procedural_test.py
@@ -6,6 +6,7 @@

 import logging
 import unittest
+import pytest

 import torch

@@ -44,59 +45,55 @@ def foobar(self, a=AbstractTensor(None, 3)):

     def testTensorEmpty(self):
         class BasicModule(CompiledModule):
-            def foobar(self, x=AbstractIndex):
-                empty = IREE.tensor_empty(x, 16)
+            def foobar(self):
+                empty = IREE.tensor_empty(1, 16)
                 dim0 = IREE.tensor_dim(empty, 0)
                 return empty, dim0

         inst = BasicModule(context=Context(), import_to=None)
         module_str = str(CompiledModule.get_mlir_module(inst))
         print(module_str)
-        self.assertIn("%0 = flow.tensor.empty : tensor<?x16xf32>{%arg0}", module_str)
-        # NOTE: We are testing below that the dynamic dimension is associated
-        # and used from the input vs being recalculated.
-        self.assertIn("return %0, %arg0 : tensor<?x16xf32>, index", module_str)
+        self.assertIn("%0 = flow.tensor.empty : tensor<1x16xf32>", module_str)
+        self.assertIn("return %0, %dim : tensor<1x16xf32>, index", module_str)

     def testTensorSplat(self):
         class BasicModule(CompiledModule):
-            def foobar(self, x=AbstractIndex, y=AbstractF32):
-                empty = IREE.tensor_splat(x, 34, value=y, dtype=torch.float32)
+            def foobar(self, y=AbstractF32):
+                empty = IREE.tensor_splat(2, 34, value=y, dtype=torch.float32)
                 dim0 = IREE.tensor_dim(empty, 0)
                 return empty, dim0

         inst = BasicModule(context=Context(), import_to=None)
         module_str = str(CompiledModule.get_mlir_module(inst))
         print(module_str)
-        self.assertIn(
-            "%0 = flow.tensor.splat %arg1 : tensor<?x34xf32>{%arg0}", module_str
-        )
+        self.assertIn("%0 = flow.tensor.splat %arg0 : tensor<2x34xf32>", module_str)
         # NOTE: We are testing below that the dynamic dimension is associated
         # and used from the input vs being recalculated.
-        self.assertIn("return %0, %arg0 : tensor<?x34xf32>, index", module_str)
+        self.assertIn("return %0, %dim : tensor<2x34xf32>, index", module_str)

     def testTensorSplatCasting(self):
         class BasicModule(CompiledModule):
-            def foobar(self, x=AbstractIndex, y=AbstractIndex):
-                empty = IREE.tensor_splat(x, 34, value=y, dtype=torch.int32)
+            def foobar(self, y=AbstractIndex):
+                empty = IREE.tensor_splat(8, 34, value=y, dtype=torch.int32)
                 dim0 = IREE.tensor_dim(empty, 0)
                 return empty, dim0

         inst = BasicModule(context=Context(), import_to=None)
         module_str = str(CompiledModule.get_mlir_module(inst))
         print(module_str)
-        self.assertIn("%0 = arith.index_castui %arg1 : index to i32", module_str)
-        self.assertIn("%1 = flow.tensor.splat %0 : tensor<?x34xi32>{%arg0}", module_str)
+        self.assertIn("%0 = arith.index_castui %arg0 : index to i32", module_str)
+        self.assertIn("%1 = flow.tensor.splat %0 : tensor<8x34xi32>", module_str)

     def testTensorTrace(self):
         class BasicModule(CompiledModule):
-            def foobar(self, x=AbstractTensor(None), y=AbstractTensor(3)):
+            def foobar(self, x=AbstractTensor(5), y=AbstractTensor(3)):
                 IREE.tensor_trace("DEBUG", x, y)

         inst = BasicModule(context=Context(), import_to=None)
         module_str = str(CompiledModule.get_mlir_module(inst))
         print(module_str)
         self.assertIn(
-            'flow.tensor.trace "DEBUG" = [%arg0 : tensor<?xf32>{%dim}, %arg1 : tensor<3xf32>]',
+            'flow.tensor.trace "DEBUG" = [%arg0 : tensor<5xf32>, %arg1 : tensor<3xf32>]',
             module_str,
         )

@@ -128,6 +125,9 @@ def foobar(self, x=AbstractTensor(3, 4)):
             module_str,
         )

+    @pytest.mark.xfail(
+        reason="CompiledModule dynamic dims no longer supported in latest torch versions"
+    )
     def testTensorSliceDynamicIndex(self):
         class SliceDynamicIndex(CompiledModule):
             def foobar(self, x=AbstractIndex):
@@ -142,6 +142,9 @@ def foobar(self, x=AbstractIndex):
             module_str,
         )

+    @pytest.mark.xfail(
+        reason="CompiledModule dynamic dims no longer supported in latest torch versions"
+    )
     def testTensorSliceDynamicLength(self):
         class SliceDynamicIndex(CompiledModule):
             def foobar(self, x=AbstractIndex, y=AbstractIndex):
@@ -175,6 +178,9 @@ def foobar(
             module_str,
         )

+    @pytest.mark.xfail(
+        reason="CompiledModule dynamic dims no longer supported in latest torch versions"
+    )
     def testTensorUpdateDynamic(self):
         class UpdateDynamic(CompiledModule):
             def foobar(
@@ -199,16 +205,16 @@ def foobar(

     def testTensorReshape(self):
         class ReshapeModule(CompiledModule):
-            def foobar(self, x=AbstractIndex, y=AbstractIndex):
-                empty = IREE.tensor_empty(x, 16)
-                reshaped = IREE.tensor_reshape(empty, 1, y, y)
+            def foobar(self):
+                empty = IREE.tensor_empty(4, 16)
+                reshaped = IREE.tensor_reshape(empty, 1, 2, 2)
                 return reshaped

         inst = ReshapeModule(context=Context(), import_to=None)
         module_str = str(CompiledModule.get_mlir_module(inst))
         print(module_str)
         self.assertIn(
-            "flow.tensor.reshape %0 : tensor<?x16xf32>{%arg0} -> tensor<1x?x?xf32>{%arg1, %arg1}",
+            "flow.tensor.reshape %0 : tensor<4x16xf32> -> tensor<1x2x2xf32>",
             module_str,
         )

5 changes: 5 additions & 0 deletions tests/aot/jittable_test.py
@@ -6,6 +6,7 @@

 import logging
 import unittest
+import pytest

 import torch

@@ -72,6 +73,9 @@ def compute(*, a, b):
         module_str = str(CompiledModule.get_mlir_module(inst))
         print(module_str)

+    @pytest.mark.xfail(
+        reason="CompiledModule dynamic dims no longer supported in latest torch versions"
+    )
     def testDynamicDims(self):
         class DynamicDimsModule(CompiledModule):
             def dynamic_dim(self, a=AbstractTensor(None, 2), b=AbstractTensor(None, 1)):
@@ -108,6 +112,7 @@ def compute(a, b):
         module_str = str(CompiledModule.get_mlir_module(inst))
         print(module_str)

+    @pytest.mark.xfail(reason="CompiledModule dynamic dims no longer supported")
     def testIrImmediateTensorAsInputToDynamicDims(self):
         class ProcArgsModule(CompiledModule):
             def dynamic_dim(self, x=AbstractIndex):
1 change: 0 additions & 1 deletion tests/dynamo/llama_test.py
@@ -315,7 +315,6 @@ def main():
     opt(example_tokens, start_pos)


-@pytest.mark.xfail(reason="https://github.com/nod-ai/SHARK-Turbine/issues/221")
 class ModelTests(unittest.TestCase):
     def testLLama(self):
         main()
3 changes: 0 additions & 3 deletions tests/examples/aot_mlp_test.py
@@ -22,9 +22,6 @@ class AOTMLPTest(unittest.TestCase):
     def testMLPExportSimple(self):
         _run("examples/aot_mlp/mlp_export_simple.py")

-    def testMLPExportSimple(self):
-        _run("examples/aot_mlp/mlp_export_dynamic.py")
-

 if __name__ == "__main__":
     logging.basicConfig(level=logging.DEBUG)
