[tripy] Add some basic greedy constant-folding optimizations during t… #42

Merged 4 commits on Aug 8, 2024
1 change: 0 additions & 1 deletion tripy/pyproject.toml
@@ -46,7 +46,6 @@ test = [
"pytest-virtualenv==1.7.0",
"pytest-cov==4.1.0",
"jax[cuda12_local]==0.4.23",
"jaxlib==0.4.23+cuda12.cudnn89",
"coverage==7.4.1",
"vulture==2.11",
]
3 changes: 1 addition & 2 deletions tripy/tests/flat_ir/ops/test_convolution.py
@@ -147,7 +147,6 @@ def test_mlir_conv(self, conv_flat_ir, padding, stride, groups, rhs_dilation):
"dim_numbers = [b, f, 0, 1]x[o, i, 0, 1]->[b, f, 0, 1], "
f"window = {{stride = {stride}, pad = {padding}, rhs_dilate = {rhs_dilation}}} "
f"{{batch_group_count = 1 : i64, feature_group_count = {groups} : i64}} "
f": (tensor<?x?x?x?xf32>, tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32>"
)
assert re.search(expected_op_call, target) and expected_op_signature in target

@@ -162,6 +161,6 @@ def test_mlir_conv_transpose(self, conv_transpose_flat_ir, padding, stride, grou
padding = new_padding
rhs_dilation = list(rhs_dilation)
expected_op_call = rf"stablehlo.convolution\(%\d+, %\d+\)"
expected_op_signature = f"dim_numbers = [b, f, 0, 1]x[o, i, 0, 1]->[b, f, 0, 1], window = {{stride = {[1] * len(stride)}, pad = {padding}, lhs_dilate = {stride}, rhs_dilate = {rhs_dilation}}} {{batch_group_count = 1 : i64, feature_group_count = {groups} : i64}} : (tensor<?x?x?x?xf32>, tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32>"
expected_op_signature = f"dim_numbers = [b, f, 0, 1]x[o, i, 0, 1]->[b, f, 0, 1], window = {{stride = {[1] * len(stride)}, pad = {padding}, lhs_dilate = {stride}, rhs_dilate = {rhs_dilation}}} {{batch_group_count = 1 : i64, feature_group_count = {groups} : i64}}"
target = str(conv_transpose_flat_ir[0].to_mlir())
assert re.search(expected_op_call, target) and expected_op_signature in target
14 changes: 13 additions & 1 deletion tripy/tripy/backend/mlir/utils.py
@@ -129,6 +129,19 @@ def make_mlir_tensor(
get_mlir_dtype(dtype),
)

def get_constant_value(arg) -> Optional[ir.DenseElementsAttr]:
from mlir_tensorrt.compiler.dialects import stablehlo

if isinstance(arg, ir.Value) and ir.OpResult.isinstance(arg):
arg = ir.OpResult(arg).owner

if isinstance(arg, ir.Operation):
arg = arg.opview

if isinstance(arg, stablehlo.ConstantOp):
return arg.value

return None

def remove_sym_attr(mlir_text: str) -> str:
return re.sub(r"module @\S+ {", "module {", mlir_text)
@@ -282,7 +295,6 @@ def get_compiler(cls):

@utils.log_time
def get_shape_of_dynamic_trace_tensor(self, trace_tensor):

from tripy.flat_ir.flat_ir import FlatIR
from tripy.frontend.utils import topological_sort
import copy
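
To make the intent of the new helper in tripy/tripy/backend/mlir/utils.py concrete, here is a minimal standalone sketch of the lookup pattern that get_constant_value implements: walk from a value back to its defining operation and return the payload only when that operation is a constant. This is plain Python with hypothetical `Op`/`Value` stand-ins, not the real mlir_tensorrt bindings.

```python
# Standalone sketch of the "peel back to the defining op" pattern.
# Op and Value are illustrative stand-ins, not the real MLIR classes.
from dataclasses import dataclass
from typing import Optional, Sequence


@dataclass
class Op:
    name: str
    attrs: dict


@dataclass
class Value:
    owner: Op


def get_constant_value(value: Value) -> Optional[Sequence[int]]:
    # Only a value produced by a constant op has a foldable payload.
    op = value.owner
    if op.name == "stablehlo.constant":
        return op.attrs.get("value")
    return None


const = Value(Op("stablehlo.constant", {"value": [1, 2, 3]}))
dynamic = Value(Op("stablehlo.add", {}))
assert get_constant_value(const) == [1, 2, 3]
assert get_constant_value(dynamic) is None
```

The ops below use this helper to decide, while lowering, whether an operand is already a known constant and can be folded greedily instead of emitted as a runtime computation.
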
7 changes: 7 additions & 0 deletions tripy/tripy/flat_ir/ops/concatenate.py
@@ -33,6 +33,13 @@ def to_mlir(self, operands):
value=self.dim,
)

# Concatenation of a single operand is a no-op.
if len(operands) == 1:
return [operands[0]]

# TODO https://github.com/NVIDIA/TensorRT-Incubator/issues/70: if we could use numpy here, we could implement constant folding for this op.
# Otherwise, implement a fold method in MLIR-TRT.

output = stablehlo.concatenate(operands, dimension=concatenate_dim)
# overwrite output type if its shape is inferred
if self.outputs[0].shape is not None:
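
For reference, the fold that the TODO above describes would look roughly like the following. This is a hypothetical sketch restricted to rank-1 i32 constants, using only the standard library (mirroring the PR's "no numpy" constraint); `fold_concatenate` and its `array.array` payloads are illustrative, not part of the PR.

```python
# Sketch: concatenating constant payloads is just concatenating their data.
import array
from typing import Optional, Sequence


def fold_concatenate(operands: Sequence[Optional[array.array]]) -> Optional[array.array]:
    # A single operand is a no-op: reuse it directly (the case handled above).
    if len(operands) == 1:
        return operands[0]
    # Only fold when every operand is a known constant.
    if any(op is None for op in operands):
        return None
    folded = array.array("i")
    for op in operands:
        folded.extend(op)
    return folded


assert fold_concatenate([array.array("i", [1, 2]), array.array("i", [3])]) == array.array("i", [1, 2, 3])
assert fold_concatenate([array.array("i", [1, 2]), None]) is None
```
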
6 changes: 4 additions & 2 deletions tripy/tripy/flat_ir/ops/gather.py
@@ -20,6 +20,7 @@

from mlir_tensorrt.compiler import ir
from mlir_tensorrt.compiler.dialects import stablehlo
from mlir_tensorrt.compiler.dialects._ods_common import get_op_result_or_value

from tripy.flat_ir.ops.base import BaseFlatIROp

@@ -31,7 +32,8 @@ class DynamicGatherOp(BaseFlatIROp):
def to_mlir(self, operands):
index_dims = self.inputs[1].rank
# Ensure slice_sizes is a static tensor with the same shape as the input.
operands[2].set_type(ir.RankedTensorType.get([self.inputs[0].rank], operands[2].type.element_type))
slice_sizes = get_op_result_or_value(operands[2])
slice_sizes.set_type(ir.RankedTensorType.get([self.inputs[0].rank], slice_sizes.type.element_type))
offset_dims = list(range(self.axis)) + list(range(self.axis + index_dims, self.inputs[0].rank + index_dims - 1))
index_vector_dim = self.inputs[1].rank

@@ -49,6 +51,6 @@ def to_mlir(self, operands):
)

gather_out = stablehlo.dynamic_gather(
operand=operands[0], start_indices=operands[1], dimension_numbers=attr, slice_sizes=operands[2]
operand=operands[0], start_indices=operands[1], dimension_numbers=attr, slice_sizes=slice_sizes
)
return [gather_out]
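
The gather change is mechanical: once earlier ops can fold to constants, an operand handed to `to_mlir` may arrive as a wrapped operation rather than a result value, and `set_type` only exists on values. A helper in the spirit of `get_op_result_or_value` normalizes this; the sketch below uses hypothetical `Operation`/`Value` stand-ins rather than the real binding types.

```python
# Sketch of the normalization: accept either an op or a value, return a value.
from dataclasses import dataclass, field
from typing import List, Union


@dataclass
class Value:
    type: str


@dataclass
class Operation:
    results: List[Value] = field(default_factory=list)


def get_op_result_or_value(arg: Union[Operation, Value]) -> Value:
    if isinstance(arg, Operation):
        assert len(arg.results) == 1, "expected a single-result op"
        return arg.results[0]
    return arg


v = Value("tensor<4xi32>")
assert get_op_result_or_value(v) is v
assert get_op_result_or_value(Operation([v])) is v
```
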
27 changes: 25 additions & 2 deletions tripy/tripy/flat_ir/ops/get_dimension_size.py
@@ -19,19 +19,42 @@

from mlir_tensorrt.compiler import ir
from mlir_tensorrt.compiler.dialects import stablehlo
from mlir_tensorrt.compiler.dialects._ods_common import get_op_result_or_value


from tripy.flat_ir.ops.base import BaseFlatIROp

# TODO: this should go in top-level Tripy config?
DIM_TENSOR_BITWIDTH = 32


@dataclass(repr=False)
class GetDimensionSizeOp(BaseFlatIROp):

dim: int

def to_mlir(self, operands):
inp = operands[0]
inp = get_op_result_or_value(operands[0])

inp_type = ir.RankedTensorType(inp.type)
assert self.dim < inp_type.rank, "expected dim to be less than rank"
dim_int_type = ir.IntegerType.get_signless(DIM_TENSOR_BITWIDTH)

# If we can view the type of the tensor and the dimension is static,
# then just materialize a constant operation.
if not ir.ShapedType.is_dynamic_size(inp_type.shape[self.dim]):
result = stablehlo.constant(
ir.DenseIntElementsAttr.get_splat(
ir.RankedTensorType.get([], dim_int_type),
ir.IntegerAttr.get(dim_int_type, inp_type.shape[self.dim]),
)
)
return [result]

# otherwise, create `stablehlo.get_dimension_size`
dim_attr = ir.IntegerAttr.get(
type=ir.IntegerType.get_signless(64),
value=self.dim,
)
return [stablehlo.get_dimension_size(inp, dimension=dim_attr)]
result = stablehlo.get_dimension_size(inp, dimension=dim_attr)
return [result]
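
The logic added to GetDimensionSizeOp is the simplest of the folds: if the queried extent is already static in the operand's type, emit a scalar i32 constant instead of a runtime `stablehlo.get_dimension_size`. The sketch below captures that decision in plain Python with hypothetical IR strings, not the real binding calls.

```python
# Sketch of the fold: static extent -> constant, dynamic extent -> runtime query.
from typing import List, Optional

DYNAMIC = -1  # stand-in for the dynamic-size sentinel used by ir.ShapedType


def lower_get_dimension_size(shape: List[int], dim: int) -> str:
    assert dim < len(shape), "expected dim to be less than rank"
    extent: Optional[int] = shape[dim] if shape[dim] != DYNAMIC else None
    if extent is not None:
        # Static extent: materialize a scalar i32 constant.
        return f"stablehlo.constant dense<{extent}> : tensor<i32>"
    # Dynamic extent: keep the runtime query.
    return f"stablehlo.get_dimension_size %arg0, dim = {dim} : tensor<i32>"


print(lower_get_dimension_size([2, DYNAMIC, 8], 0))  # folds to a constant 2
print(lower_get_dimension_size([2, DYNAMIC, 8], 1))  # stays a runtime query
```
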
29 changes: 26 additions & 3 deletions tripy/tripy/flat_ir/ops/reshape.py
@@ -15,21 +15,39 @@
# limitations under the License.
#

import array
from dataclasses import dataclass
from typing import Optional, Sequence

from mlir_tensorrt.compiler.dialects import stablehlo
from mlir_tensorrt.compiler.dialects._ods_common import get_op_result_or_value
from mlir_tensorrt.compiler import ir

from tripy.flat_ir.ops.base import BaseFlatIROp
from tripy.backend.mlir.utils import is_any_dim_dynamic
from tripy.backend.mlir.utils import is_any_dim_dynamic, get_constant_value
import tripy.utils.utils as utils


def _do_static_reshape(arg, new_shape: Sequence[int]):
# If the input is a constant, then just reshape the constant.
const_input = get_constant_value(arg)

# For now, just handle i32 types: without the convenience of numpy, each element
# type would need to be handled separately using 'array.array'.
if const_input and ir.IntegerType.get_signless(32) == const_input.type.element_type:
new_type = ir.RankedTensorType.get(new_shape, const_input.type.element_type)
new_attr = ir.DenseElementsAttr.get(array=array.array("i", const_input), type=new_type)
return stablehlo.constant(new_attr)

arg = get_op_result_or_value(arg)
output_type = ir.RankedTensorType.get(new_shape, arg.type.element_type)
return stablehlo.reshape(output_type, arg)


@dataclass(repr=False)
class ReshapeOp(BaseFlatIROp):
def to_mlir(self, operands):
output = stablehlo.ReshapeOp(result=self.outputs[0].to_mlir(), operand=operands[0])
return [output]
return [_do_static_reshape(operands[0], self.outputs[0].to_mlir().shape)]


class DynamicReshapeOp(BaseFlatIROp):
@@ -41,6 +59,11 @@ def to_mlir(self, operands):
self.inputs[1].shape = new_shape
operands[1].set_type(ir.RankedTensorType.get(new_shape, operands[1].type.element_type))

# If the shape is a constant, then we can just do static reshape.
const_shape_value = get_constant_value(operands[1])
if const_shape_value:
return [_do_static_reshape(operands[0], list(const_shape_value))]

output = stablehlo.dynamic_reshape(
result=self.outputs[0].to_mlir(), operand=operands[0], output_shape=operands[1]
)
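
The constant path in `_do_static_reshape` relies on the fact that the payload of a dense constant is stored flat, so folding a reshape only rebuilds the attribute with a new type; the element data never moves. Here is a standalone sketch of that invariant with `array.array`, mirroring the PR's "no numpy" constraint; `fold_static_reshape` is an illustrative name, not part of the PR.

```python
# Sketch: a static reshape of a flat i32 payload is purely a type change.
import array
import math
from typing import Sequence, Tuple


def fold_static_reshape(
    payload: array.array, old_shape: Sequence[int], new_shape: Sequence[int]
) -> Tuple[array.array, Tuple[int, ...]]:
    assert payload.typecode == "i", "only i32 payloads are folded for now"
    assert math.prod(old_shape) == math.prod(new_shape) == len(payload)
    # The data stays flat and untouched; only the advertised shape changes.
    return payload, tuple(new_shape)


data = array.array("i", range(6))
folded, shape = fold_static_reshape(data, (2, 3), (3, 2))
assert folded is data and shape == (3, 2)
```

The same observation explains the DynamicReshapeOp change: when the shape operand turns out to be a constant, the dynamic reshape degenerates to the static case and can be folded the same way.
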
3 changes: 2 additions & 1 deletion tripy/tripy/flat_ir/ops/slice.py
@@ -20,6 +20,7 @@

from mlir_tensorrt.compiler import ir
from mlir_tensorrt.compiler.dialects import stablehlo
from mlir_tensorrt.compiler.dialects._ods_common import get_op_result_or_value

from tripy.flat_ir.ops.base import BaseFlatIROp
from tripy.backend.mlir.utils import is_any_dim_dynamic
@@ -52,7 +53,7 @@ def to_mlir(self, operands):
if any(dynamic_dim_attrs):
assert static_dim_attrs, "DynamicSliceOp requires at-least 1 attribute to be of static shape."
for d in dynamic_dim_attrs:
new_shape = [s for s in static_dim_attrs[0].type.shape]
new_shape = [s for s in get_op_result_or_value(static_dim_attrs[0]).type.shape]
d.set_type(ir.RankedTensorType.get(new_shape, d.type.element_type))

return [
Expand Down