Skip to content

Commit

Permalink
Fix gpu_sparse_dot_test_gpu_a100 by including argument materialization
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 659485091
  • Loading branch information
tensorflower-gardener authored and copybara-github committed Aug 5, 2024
1 parent b7c2cc3 commit 2871885
Showing 1 changed file with 11 additions and 10 deletions.
21 changes: 11 additions & 10 deletions third_party/triton/xla_extensions/sparsity_layout.patch
Original file line number Diff line number Diff line change
Expand Up @@ -2,19 +2,20 @@ diff --git a/lib/Conversion/TritonToTritonGPU/TritonGPUConversion.cpp b/lib/Conv
index 34fb89954..a0172e107 100644
--- a/lib/Conversion/TritonToTritonGPU/TritonGPUConversion.cpp
+++ b/lib/Conversion/TritonToTritonGPU/TritonGPUConversion.cpp
@@ -57,8 +57,9 @@ TritonGPUTypeConverter::TritonGPUTypeConverter(MLIRContext *context,
@@ -57,6 +57,12 @@ TritonGPUTypeConverter::TritonGPUTypeConverter(MLIRContext *context,
addArgumentMaterialization([&](OpBuilder &builder,
RankedTensorType tensorType, ValueRange inputs,
Location loc) -> std::optional<Value> {
- llvm_unreachable("Argument rematerialization should not happen in Triton "
- "-> TritonGPU conversion");
+ // TODO(b/354860562): reenable or remove.
+ // llvm_unreachable("Argument rematerialization should not happen in Triton "
+ // "-> TritonGPU conversion");
+ // Allows partial TTIR to TTGIR conversion by materializing a conversion for
+ // remaining arguments that have been converted to a new type.
+ // We use this to rewrite triton_gpu.sparse_dot in a separate pass after
+ // 'convert-triton-to-tritongpu'.
+ return builder.create<triton::gpu::ConvertLayoutOp>(loc, tensorType,
+ inputs);
llvm_unreachable("Argument rematerialization should not happen in Triton "
"-> TritonGPU conversion");
return std::nullopt;
});

@@ -67,6 +68,12 @@ TritonGPUTypeConverter::TritonGPUTypeConverter(MLIRContext *context,
@@ -67,6 +73,12 @@ TritonGPUTypeConverter::TritonGPUTypeConverter(MLIRContext *context,
addSourceMaterialization([&](OpBuilder &builder, RankedTensorType tensorType,
ValueRange inputs,
Location loc) -> std::optional<Value> {
Expand All @@ -31,7 +32,7 @@ diff --git a/lib/Dialect/TritonGPU/IR/Dialect.cpp b/lib/Dialect/TritonGPU/IR/Dia
index df3d3b042..e38c184f6 100644
--- a/lib/Dialect/TritonGPU/IR/Dialect.cpp
+++ b/lib/Dialect/TritonGPU/IR/Dialect.cpp
@@ -2867,13 +2867,13 @@ struct CanonicalizeConvertFromConvert
@@ -2867,13 +2879,13 @@ struct CanonicalizeConvertFromConvert
// heuristic to accommodate fused attention.
auto srcType = op.getSrc().getType();
auto dstType = op.getType();
Expand Down

0 comments on commit 2871885

Please sign in to comment.