Skip to content

Commit

Permalink
Add support for fp16-to-fp32 conversion using ExtFOp
Browse files Browse the repository at this point in the history
  • Loading branch information
sasha0552 authored May 11, 2024
1 parent 161f7a4 commit 5d76f19
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 4 deletions.
8 changes: 6 additions & 2 deletions lib/Dialect/TritonGPU/Transforms/AccelerateMatmul.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -339,8 +339,12 @@ class BlockedToMMA : public mlir::RewritePattern {

/// Emits an op that converts `operand` to a tensor whose element type is
/// `promotedType`.
///
/// \param builder       Builder used to create the conversion op.
/// \param loc           Location attached to the created op.
/// \param operand       Value to convert. Its type is cast to RankedTensorType
///                      with no prior check in this function, so callers are
///                      expected to pass ranked tensors only.
/// \param promotedType  Element type of the result tensor.
/// \return The result of an arith::ExtFOp when going from f16 to f32,
///         otherwise the result of a tt::FpToFpOp.
static Value promoteOperand(OpBuilder &builder, Location loc, Value operand,
                            Type promotedType) {
  RankedTensorType tensor = cast<RankedTensorType>(operand.getType());
  Type tensorElementType = tensor.getElementType();
  // Declared exactly once: result type derived from the operand's tensor type
  // with the element type swapped for `promotedType` (std::nullopt is passed
  // for the shape argument -- presumably "keep the existing shape"; confirm
  // against the MLIR cloneWith documentation).
  Type tensorPromotedType = tensor.cloneWith(std::nullopt, promotedType);
  // f16 -> f32 is emitted as a plain arith extension instead of FpToFpOp.
  if (tensorElementType.isF16() && promotedType.isF32()) {
    return builder.create<arith::ExtFOp>(loc, tensorPromotedType, operand);
  }
  return builder.create<tt::FpToFpOp>(loc, tensorPromotedType, operand);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -207,8 +207,12 @@ struct ConvertTritonGPUToLLVM

/// Emits an op that converts `operand` to a tensor whose element type is
/// `promotedType`.
///
/// \param builder       Builder used to create the conversion op.
/// \param loc           Location attached to the created op.
/// \param operand       Value to convert. Its type is cast to RankedTensorType
///                      with no prior check in this function, so callers are
///                      expected to pass ranked tensors only.
/// \param promotedType  Element type of the result tensor.
/// \return The result of an arith::ExtFOp when going from f16 to f32,
///         otherwise the result of a triton::FpToFpOp.
static Value promoteOperand(OpBuilder &builder, Location loc, Value operand,
                            Type promotedType) {
  RankedTensorType tensor = cast<RankedTensorType>(operand.getType());
  Type tensorElementType = tensor.getElementType();
  // Declared exactly once: result type derived from the operand's tensor type
  // with the element type swapped for `promotedType` (std::nullopt is passed
  // for the shape argument -- presumably "keep the existing shape"; confirm
  // against the MLIR cloneWith documentation).
  Type tensorPromotedType = tensor.cloneWith(std::nullopt, promotedType);
  // f16 -> f32 is emitted as a plain arith extension instead of FpToFpOp.
  if (tensorElementType.isF16() && promotedType.isF32()) {
    return builder.create<arith::ExtFOp>(loc, tensorPromotedType, operand);
  }
  return builder.create<triton::FpToFpOp>(loc, tensorPromotedType, operand);
}
};
Expand Down

0 comments on commit 5d76f19

Please sign in to comment.