Commit b6e812b
change softmax lower
weilinquan committed Jan 2, 2024
1 parent f2ef534 commit b6e812b
Showing 2 changed files with 31 additions and 202 deletions.
1 change: 0 additions & 1 deletion frontend/Python/graph/op_def.py
@@ -111,7 +111,6 @@ class GetItemOp(Op):
     def __init__(self) -> None:
         super().__init__()
         self._op_type = OpType.GetItemType
-        self._lower_strategy = []


 class OutputOp(Op):
232 changes: 31 additions & 201 deletions frontend/Python/ops/linalg.py
@@ -1276,7 +1276,7 @@ def neg_op(
 ):
     """
     Import the tensor neg operation.
-    From buddy NegOp to MLIR linalg `matmul` operation.
+    From buddy NegOp to MLIR linalg `negf` operation.
     Note: This op computes the input node's neg result.
     Args:
@@ -1480,27 +1480,9 @@ def batch_matmul_op(
     dtype = node.tensor_meta["dtype"]
     mlir_dtype = mlir_element_type_get(dtype)
     tensor_type = ir.RankedTensorType.get(output_shape, mlir_dtype)
-    output = tensor.EmptyOp(output_shape, mlir_dtype)
-    # use linalg.generic implementation
-    generic_map = ir.AffineMap.get_permutation([0, 1, 2])
-    zero_fill = linalg.GenericOp(
-        [tensor_type],
-        [],
-        [output],
-        ir.ArrayAttr.get(
-            [ir.AffineMapAttr.get(generic_map.get_submap([0, 1, 2]))]
-        ),
-        ir.ArrayAttr.get(
-            [ir.Attribute.parse("#linalg.iterator_type<parallel>")] * 3
-        ),
-    )
-    block = ir.Block.create_at_start(
-        zero_fill.region,
-        [ir.RankedTensorType(output.result.type).element_type],
-    )
-    zero_op = arith.ConstantOp(mlir_dtype, mlir_element_attr_get(dtype, 0))
-    block.append(zero_op)
-    block.append(linalg.YieldOp([zero_op.result]))
+    element = mlir_element_attr_get(dtype, 0)
+    attr = ir.DenseElementsAttr.get_splat(tensor_type, element)
+    zero_fill = arith.ConstantOp(tensor_type, attr)
     op = linalg.batch_matmul(input1, input2, outs=[zero_fill.result])

     return op
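
Both the removed linalg.generic fill and the new dense-splat constant serve the same purpose: linalg.batch_matmul treats its outs operand as an accumulator, so it must be seeded with zeros before the multiply-accumulate runs. A minimal NumPy sketch of those semantics (the function name and shapes here are illustrative, not part of the commit):

import numpy as np

def batch_matmul_reference(a: np.ndarray, b: np.ndarray) -> np.ndarray:
    """Mimic linalg.batch_matmul: out starts as a zero splat and is
    accumulated into, matching the arith.constant dense<0> seeding."""
    batch, m, k = a.shape
    _, _, n = b.shape
    out = np.zeros((batch, m, n), dtype=a.dtype)  # the zero_fill constant
    for i in range(batch):
        out[i] += a[i] @ b[i]  # C[i] += A[i] * B[i], the accumulate step
    return out

a = np.random.rand(2, 3, 4).astype(np.float32)
b = np.random.rand(2, 4, 5).astype(np.float32)
assert np.allclose(batch_matmul_reference(a, b), a @ b)

Materializing the zero accumulator as a single constant op avoids building an extra generic region just to yield zero.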
@@ -1586,200 +1568,49 @@ def softmax_op(
     # ir.IntegerAttr.get(ir.IntegerType.get_signless(64), dim),
     # )
     # print(op, flush=True)
-    max_tensor_shape = copy.deepcopy(output_shape)
-    max_tensor_shape[dim] = 1
-    max_tensor_type = ir.RankedTensorType.get(max_tensor_shape, mlir_dtype)
-    max_tensor = tensor.EmptyOp(max_tensor_shape, mlir_dtype)
-    max_tensor_map = [
-        ir.AffineExpr.get_dim(i) for i in range(len(max_tensor_shape))
-    ]
-    max_tensor_map = ir.AffineMap.get(len(max_tensor_shape), 0, max_tensor_map)
-    neg_inf_fill = linalg.GenericOp(
-        [max_tensor_type],
-        [],
-        [max_tensor],
-        ir.ArrayAttr.get([ir.AffineMapAttr.get(max_tensor_map)]),
-        ir.ArrayAttr.get(
-            [ir.Attribute.parse("#linalg.iterator_type<parallel>")]
-            * len(max_tensor_shape)
-        ),
-    )
-    block = ir.Block.create_at_start(
-        neg_inf_fill.region,
-        [ir.RankedTensorType(max_tensor.result.type).element_type],
-    )
-    neg_inf_op = arith.ConstantOp(
-        mlir_dtype, mlir_element_attr_get(dtype, float("-inf"))
-    )
-    block.append(neg_inf_op)
-    block.append(linalg.YieldOp([neg_inf_op.result]))
+    sum_tensor_shape = copy.deepcopy(output_shape)
+    sum_tensor_shape[dim] = 1
+    sum_tensor_type = ir.RankedTensorType.get(sum_tensor_shape, mlir_dtype)
+    element = mlir_element_attr_get(dtype, 0)
+    attr = ir.DenseElementsAttr.get_splat(sum_tensor_type, element)
+    sum_tensor = arith.ConstantOp(sum_tensor_type, attr).result
     input1_map = [ir.AffineExpr.get_dim(i) for i in range(len(output_shape))]
     input1_map = ir.AffineMap.get(len(output_shape), 0, input1_map)
-    max_tensor_map = [
+    sum_tensor_map = [
         ir.AffineExpr.get_dim(i) for i in range(len(output_shape))
     ]
-    max_tensor_map[dim] = ir.AffineExpr.get_constant(0)
-    max_tensor_map = ir.AffineMap.get(len(output_shape), 0, max_tensor_map)
+    sum_tensor_map[dim] = ir.AffineExpr.get_constant(0)
+    sum_tensor_map = ir.AffineMap.get(len(output_shape), 0, sum_tensor_map)
     loop_type = [ir.Attribute.parse("#linalg.iterator_type<parallel>")] * len(
         output_shape
     )
     loop_type[dim] = ir.Attribute.parse("#linalg.iterator_type<reduction>")
-    max_tensor_op = linalg.GenericOp(
-        [max_tensor_type],
+    sum_tensor_op = linalg.GenericOp(
+        [sum_tensor_type],
         [input1],
-        [neg_inf_fill],
+        [sum_tensor],
         ir.ArrayAttr.get(
             [
                 ir.AffineMapAttr.get(input1_map),
-                ir.AffineMapAttr.get(max_tensor_map),
+                ir.AffineMapAttr.get(sum_tensor_map),
             ]
         ),
         ir.ArrayAttr.get(loop_type),
     )
     block = ir.Block.create_at_start(
-        max_tensor_op.region,
+        sum_tensor_op.region,
         [
-            ir.RankedTensorType(input1.type).element_type,
-            ir.RankedTensorType(neg_inf_fill.result.type).element_type,
+            mlir_dtype,
+            mlir_dtype,
         ],
     )
-    max_op = arith.MaximumFOp(block.arguments[0], block.arguments[1])
-    block.append(max_op)
-    block.append(linalg.YieldOp([max_op.result]))
-    exp_tensor = tensor.EmptyOp(output_shape, mlir_dtype)
-    exp_tensor_type = ir.RankedTensorType.get(output_shape, mlir_dtype)
-    input1_map = [ir.AffineExpr.get_dim(i) for i in range(len(output_shape))]
-    input1_map = ir.AffineMap.get(len(output_shape), 0, input1_map)
-    max_tensor_map = [
-        ir.AffineExpr.get_dim(i) for i in range(len(output_shape))
-    ]
-    max_tensor_map[dim] = ir.AffineExpr.get_constant(0)
-    max_tensor_map = ir.AffineMap.get(len(output_shape), 0, max_tensor_map)
-    exp_tensor_map = [
-        ir.AffineExpr.get_dim(i) for i in range(len(output_shape))
-    ]
-    exp_tensor_map = ir.AffineMap.get(len(output_shape), 0, exp_tensor_map)
-    exp_tensor_op = linalg.GenericOp(
-        [exp_tensor_type],
-        [input1, max_tensor_op.result],
-        [exp_tensor],
-        ir.ArrayAttr.get(
-            [
-                ir.AffineMapAttr.get(input1_map),
-                ir.AffineMapAttr.get(max_tensor_map),
-                ir.AffineMapAttr.get(exp_tensor_map),
-            ]
-        ),
-        ir.ArrayAttr.get(
-            [ir.Attribute.parse("#linalg.iterator_type<parallel>")]
-            * len(output_shape)
-        ),
-    )
-    block = ir.Block.create_at_start(
-        exp_tensor_op.region,
-        [
-            ir.RankedTensorType(input1.type).element_type,
-            ir.RankedTensorType(max_tensor_op.result.type).element_type,
-            ir.RankedTensorType(exp_tensor.result.type).element_type,
-        ],
-    )
-    if str(mlir_dtype).find("i") != -1:
-        sub_op = arith.SubIOp(block.arguments[0], block.arguments[1])
-    else:
-        sub_op = arith.SubFOp(block.arguments[0], block.arguments[1])
-    exp_op = math.ExpOp(sub_op.result)
-    block.append(sub_op)
+    exp_op = math.ExpOp(block.arguments[0])
+    add_op = arith.AddFOp(exp_op.result, block.arguments[1])
     block.append(exp_op)
-    block.append(linalg.YieldOp([exp_op.result]))
-    reduce_sum_tensor_shape = copy.deepcopy(output_shape)
-    reduce_sum_tensor_shape[dim] = 1
-    reduce_sum_tensor = tensor.EmptyOp(reduce_sum_tensor_shape, mlir_dtype)
-    reduce_sum_tensor_type = ir.RankedTensorType.get(
-        reduce_sum_tensor_shape, mlir_dtype
-    )
-    reduce_sum_tensor_map = [
-        ir.AffineExpr.get_dim(i) for i in range(len(output_shape))
-    ]
-    reduce_sum_tensor_map = ir.AffineMap.get(
-        len(output_shape), 0, reduce_sum_tensor_map
-    )
-    zero_fill_op = linalg.GenericOp(
-        [reduce_sum_tensor_type],
-        [],
-        [reduce_sum_tensor.result],
-        ir.ArrayAttr.get([ir.AffineMapAttr.get(reduce_sum_tensor_map)]),
-        ir.ArrayAttr.get(
-            [ir.Attribute.parse("#linalg.iterator_type<parallel>")]
-            * len(output_shape)
-        ),
-    )
-    block = ir.Block.create_at_start(
-        zero_fill_op.region,
-        [ir.RankedTensorType(reduce_sum_tensor.result.type).element_type],
-    )
-    zero_op = arith.ConstantOp(mlir_dtype, mlir_element_attr_get(dtype, 0))
-    block.append(zero_op)
-    block.append(linalg.YieldOp([zero_op.result]))
-    reduce_sum_tensor_shape = copy.deepcopy(output_shape)
-    reduce_sum_tensor_shape[dim] = 1
-    reduce_sum_tensor_type = ir.RankedTensorType.get(
-        reduce_sum_tensor_shape, mlir_dtype
-    )
-    exp_tensor_map = [
-        ir.AffineExpr.get_dim(i) for i in range(len(output_shape))
-    ]
-    exp_tensor_map = ir.AffineMap.get(len(output_shape), 0, exp_tensor_map)
-    reduce_sum_tensor_map = [
-        ir.AffineExpr.get_dim(i) for i in range(len(output_shape))
-    ]
-    reduce_sum_tensor_map[dim] = ir.AffineExpr.get_constant(0)
-    reduce_sum_tensor_map = ir.AffineMap.get(
-        len(output_shape), 0, reduce_sum_tensor_map
-    )
-    loop_type = [ir.Attribute.parse("#linalg.iterator_type<parallel>")] * len(
-        output_shape
-    )
-    loop_type[dim] = ir.Attribute.parse("#linalg.iterator_type<reduction>")
-    reduce_sum_tensor_op = linalg.GenericOp(
-        [reduce_sum_tensor_type],
-        [exp_tensor_op.result],
-        [zero_fill_op.result],
-        ir.ArrayAttr.get(
-            [
-                ir.AffineMapAttr.get(exp_tensor_map),
-                ir.AffineMapAttr.get(reduce_sum_tensor_map),
-            ]
-        ),
-        ir.ArrayAttr.get(loop_type),
-    )
-    block = ir.Block.create_at_start(
-        reduce_sum_tensor_op.region,
-        [
-            ir.RankedTensorType(exp_tensor_op.result.type).element_type,
-            ir.RankedTensorType(zero_fill_op.result.type).element_type,
-        ],
-    )
-    if str(mlir_dtype).find("i") != -1:
-        add_op = arith.AddIOp(block.arguments[0], block.arguments[1])
-    else:
-        add_op = arith.AddFOp(block.arguments[0], block.arguments[1])
     block.append(add_op)
     block.append(linalg.YieldOp([add_op.result]))
-    reduce_sum_tensor_shape = copy.deepcopy(output_shape)
-    reduce_sum_tensor_shape[dim] = 1
     result_tensor_type = ir.RankedTensorType.get(output_shape, mlir_dtype)
     result_tensor = tensor.EmptyOp(output_shape, mlir_dtype)
-    exp_tensor_map = [
-        ir.AffineExpr.get_dim(i) for i in range(len(output_shape))
-    ]
-    exp_tensor_map = ir.AffineMap.get(len(output_shape), 0, exp_tensor_map)
-    reduce_sum_tensor_map = [
-        ir.AffineExpr.get_dim(i) for i in range(len(output_shape))
-    ]
-    reduce_sum_tensor_map[dim] = ir.AffineExpr.get_constant(0)
-    reduce_sum_tensor_map = ir.AffineMap.get(
-        len(output_shape), 0, reduce_sum_tensor_map
-    )
     result_tensor_map = [
         ir.AffineExpr.get_dim(i) for i in range(len(output_shape))
     ]
@@ -1788,12 +1619,12 @@ def softmax_op(
     )
     op = linalg.GenericOp(
         [result_tensor_type],
-        [exp_tensor_op.result, reduce_sum_tensor_op.result],
+        [input1, sum_tensor_op.result],
         [result_tensor.result],
         ir.ArrayAttr.get(
             [
-                ir.AffineMapAttr.get(exp_tensor_map),
-                ir.AffineMapAttr.get(reduce_sum_tensor_map),
+                ir.AffineMapAttr.get(input1_map),
+                ir.AffineMapAttr.get(sum_tensor_map),
                 ir.AffineMapAttr.get(result_tensor_map),
             ]
         ),
@@ -1805,15 +1636,14 @@ def softmax_op(
     block = ir.Block.create_at_start(
         op.region,
         [
-            ir.RankedTensorType(exp_tensor_op.result.type).element_type,
-            ir.RankedTensorType(reduce_sum_tensor_op.result.type).element_type,
-            ir.RankedTensorType(result_tensor.result.type).element_type,
+            mlir_dtype,
+            mlir_dtype,
+            mlir_dtype,
         ],
     )
-    if str(mlir_dtype).find("i") != -1:
-        div_op = arith.DivSIOp(block.arguments[0], block.arguments[1])
-    else:
-        div_op = arith.DivFOp(block.arguments[0], block.arguments[1])
+    exp_op = math.ExpOp(block.arguments[0])
+    div_op = arith.DivFOp(exp_op.result, block.arguments[1])
+    block.append(exp_op)
     block.append(div_op)
     block.append(linalg.YieldOp([div_op.result]))

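Taken together, the rewrite collapses the old softmax lowering (neg-inf fill, max reduction, subtract-max-and-exp, zero fill, sum reduction, divide) into two generic ops: a sum-of-exp reduction seeded by a zero splat, and an elementwise exp-and-divide. A NumPy sketch of both formulations (illustrative only, not code from the repository); note that the new form drops the max-subtraction that conventionally guards exp() against overflow on large-magnitude inputs:

import numpy as np

def softmax_new(x: np.ndarray, dim: int) -> np.ndarray:
    """The new lowering: one reduction computing sum(exp(x)) along dim,
    then an elementwise exp(x) / sum."""
    s = np.sum(np.exp(x), axis=dim, keepdims=True)  # sum_tensor_op
    return np.exp(x) / s                            # final linalg.GenericOp

def softmax_old(x: np.ndarray, dim: int) -> np.ndarray:
    """The removed lowering: subtract the running max before exp, the
    standard guard against overflow for large inputs."""
    m = np.max(x, axis=dim, keepdims=True)          # max_tensor_op
    e = np.exp(x - m)                               # exp_tensor_op
    return e / np.sum(e, axis=dim, keepdims=True)   # reduce_sum + divide

x = np.random.randn(2, 8).astype(np.float32)
assert np.allclose(softmax_new(x, 1), softmax_old(x, 1), atol=1e-6)

The two agree for moderate inputs; the simpler form trades that overflow guard for one fewer reduction and fewer generic regions.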
