Skip to content

Commit

Permalink
[examples][MLIRPython] Fix var_mean_op and remove unused operator
Browse files Browse the repository at this point in the history
  • Loading branch information
xTayEx committed Oct 14, 2023
1 parent 521ce61 commit dd326d5
Showing 1 changed file with 6 additions and 14 deletions.
20 changes: 6 additions & 14 deletions examples/MLIRPython/buddy/operators_gen.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,6 @@ def addmm_op(node: torch.fx.Node,
add_result_tensor_type = ir.RankedTensorType.get(final_result_shape,
result_element_type)


op = _gen_arith_binary_op(input_, matmul_result_reshape_op.result, tosa.AddOp)
return op

Expand All @@ -128,12 +127,6 @@ def bmm_op(node: torch.fx.Node,
return op


def gt_op(node, symbol_table):
input1 = symbol_table.get((str(node.args[0]), 0), node.args[0])
input2 = symbol_table.get((str(node.args[1]), 0), node.args[1])
return _gen_arith_binary_op(input1, input2, tosa.GreaterOp)


def add_op(node, symbol_table):
input1 = symbol_table.get((str(node.args[0]), 0), node.args[0])
input2 = symbol_table.get((str(node.args[1]), 0), node.args[1])
Expand Down Expand Up @@ -253,7 +246,6 @@ def unsqueeze_op(node, symbol_table):
sizes.insert(dim, 1)
new_shape_content = array.array("i", sizes)
new_shape_content = memoryview(new_shape_content)
# new_shape_attr = ir.DenseElementsAttr.get(new_shape_content)
op = tosa.ReshapeOp(input_tensor, new_shape_content)
return op

Expand Down Expand Up @@ -361,7 +353,7 @@ def mean_dim_op(_input_tensor: ir.Value, _dim) -> ir.Operation:
reduce_sum_result = _input_tensor

# reduce along each dimension in `_dim`
for _dim_item, _ in enumerate(_dim):
for _dim_item in _dim:
reduce_dim_attr = ir.IntegerAttr.get(ir.IntegerType.get_signless(64),
_dim_item)
reduce_sum_op: ir.Operation = tosa.ReduceSumOp(reduce_sum_result,
Expand Down Expand Up @@ -408,7 +400,7 @@ def var_dim_op(_input_tensor: ir.Value, _mean_tensor: ir.Value, _dim,

# the result of `mul_op` is the first tensor we need to reduce
reduce_sum_op = mul_op
for _dim_item, _ in enumerate(_dim):
for _dim_item in _dim:
reduce_dim_attr = ir.IntegerAttr.get(ir.IntegerType.get_signless(64),
_dim_item)
reduce_sum_op: ir.Operation = tosa.ReduceSumOp(reduce_sum_op.results[0],
Expand Down Expand Up @@ -438,7 +430,7 @@ def var_dim_op(_input_tensor: ir.Value, _mean_tensor: ir.Value, _dim,
var_input_tensor = symbol_table.get((str(node.args[0]), 0))

kwargs = node.kwargs
keep_dim = kwargs.get("keep_dim", False)
keepdim = kwargs.get("keepdim", False)
correction = kwargs.get("correction", 1.0)

mean_op = None
Expand All @@ -447,14 +439,14 @@ def var_dim_op(_input_tensor: ir.Value, _mean_tensor: ir.Value, _dim,
calc_dims = range(len(ir.RankedTensorType(mean_input_tensor.type).shape))
else:
calc_dims = node.args[1]

mean_op = mean_dim_op(mean_input_tensor, calc_dims)
var_op = var_dim_op(var_input_tensor, mean_op.results[0], calc_dims,
correction)
mean_input_tensor = mean_op.results[0]
var_input_tensor = var_op.results[0]

if not keep_dim:
if not keepdim:
result_shp = ir.RankedTensorType(var_op.results[0].type).shape
result_shp = [siz for siz in result_shp if siz != 1]
var_op = tosa.ReshapeOp(var_op.results[0],
Expand Down Expand Up @@ -583,4 +575,4 @@ def sum_op(node, symbol_table):
"convert_element_type.default": convert_element_type_op,
"permute.default": permute_op,
"unsqueeze.default": unsqueeze_op,
}
}

0 comments on commit dd326d5

Please sign in to comment.