Skip to content

Commit

Permalink
Allow CustomKernelFusionRewriter to manualy specify the kernel index.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 676336179
  • Loading branch information
derdrdirk authored and Google-ML-Automation committed Sep 19, 2024
1 parent 42b04a6 commit 45dea71
Show file tree
Hide file tree
Showing 4 changed files with 50 additions and 26 deletions.
30 changes: 15 additions & 15 deletions xla/service/gpu/kernels/cutlass_gemm_fusion_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@ TEST_F(CutlassFusionTest, RowMajorGemm) {
patterns.Emplace<CutlassGemmPattern>();

auto device = TestGpuDeviceInfo::RTXA6000DeviceInfo();
CustomKernelFusionRewriter pass(&device, &patterns);
CustomKernelFusionRewriter pass(&device, /*kernel_index=*/0, &patterns);
RunAndFilecheckHloRewrite(hlo, std::move(pass), expected);
}

Expand Down Expand Up @@ -141,7 +141,7 @@ TEST_F(CutlassFusionTest, RowMajorGemmWithUpcast) {
patterns.Emplace<CutlassGemmWithUpcastPattern>();

auto device = TestGpuDeviceInfo::RTXA6000DeviceInfo();
CustomKernelFusionRewriter pass(&device, &patterns);
CustomKernelFusionRewriter pass(&device, /*kernel_index=*/0, &patterns);
RunAndFilecheckHloRewrite(hlo, std::move(pass), expected);
}

Expand Down Expand Up @@ -183,7 +183,7 @@ TEST_F(CutlassFusionTest, RowMajorGemmWithUpcastOfBothOperands) {
patterns.Emplace<CutlassGemmWithUpcastPattern>();

auto device = TestGpuDeviceInfo::RTXA6000DeviceInfo();
CustomKernelFusionRewriter pass(&device, &patterns);
CustomKernelFusionRewriter pass(&device, /*kernel_index=*/0, &patterns);
RunAndFilecheckHloRewrite(hlo, std::move(pass), expected);
}

Expand All @@ -209,7 +209,7 @@ TEST_F(CutlassFusionTest, DoNotPatternMatchNotImplementedKernelTypes) {
ParseAndReturnVerifiedModule(hlo);

auto device = TestGpuDeviceInfo::RTXA6000DeviceInfo();
CustomKernelFusionRewriter pass(&device, &patterns);
CustomKernelFusionRewriter pass(&device, /*kernel_index=*/0, &patterns);

ASSERT_FALSE(pass.Run(hlo_module.value().get()).value());
}
Expand Down Expand Up @@ -260,7 +260,7 @@ TEST_F(CutlassFusionTest, RowMajorGemmWithDynamicUpdateSlice) {
patterns.Emplace<CutlassGemmWithDynamicUpdateSlicePattern>();

auto device = TestGpuDeviceInfo::RTXA6000DeviceInfo();
CustomKernelFusionRewriter pass(&device, &patterns);
CustomKernelFusionRewriter pass(&device, /*kernel_index=*/0, &patterns);
RunAndFilecheckHloRewrite(hlo, std::move(pass), expected);
}

Expand Down Expand Up @@ -319,7 +319,7 @@ TEST_F(CutlassFusionTest, RowMajorGemmWithDynamicUpdateSliceMultipleUses) {
patterns.Emplace<CutlassGemmWithDynamicUpdateSlicePattern>();

auto device = TestGpuDeviceInfo::RTXA6000DeviceInfo();
CustomKernelFusionRewriter pass(&device, &patterns);
CustomKernelFusionRewriter pass(&device, /*kernel_index=*/0, &patterns);
RunAndFilecheckHloRewrite(hlo, std::move(pass), expected);
}

Expand Down Expand Up @@ -365,7 +365,7 @@ TEST_F(CutlassFusionTest, RowMajorGemmWithDynamicUpdateSliceWithoutBitcast) {
patterns.Emplace<CutlassGemmWithDynamicUpdateSlicePattern>();

auto device = TestGpuDeviceInfo::RTXA6000DeviceInfo();
CustomKernelFusionRewriter pass(&device, &patterns);
CustomKernelFusionRewriter pass(&device, /*kernel_index=*/0, &patterns);
RunAndFilecheckHloRewrite(hlo, std::move(pass), expected);
}

Expand Down Expand Up @@ -654,7 +654,7 @@ TEST_F(CutlassFusionTest, GemmWithUpcastShouldBeFused) {
patterns.Emplace<CutlassGemmWithUpcastPattern>();

auto device = TestGpuDeviceInfo::RTXA6000DeviceInfo();
CustomKernelFusionRewriter pass(&device, &patterns);
CustomKernelFusionRewriter pass(&device, /*kernel_index=*/0, &patterns);
RunAndFilecheckHloRewrite(hlo, std::move(pass), expected);
EXPECT_TRUE(RunAndCompare(hlo, ErrorSpec{1e-3, 1e-3}));
}
Expand All @@ -674,7 +674,7 @@ TEST_F(CutlassFusionTest,
patterns.Emplace<CutlassGemmWithUpcastPattern>();

auto device = TestGpuDeviceInfo::RTXA6000DeviceInfo();
CustomKernelFusionRewriter pass(&device, &patterns);
CustomKernelFusionRewriter pass(&device, /*kernel_index=*/0, &patterns);
// Check that hlo is not rewritten after the pass, indicating that we don't
// match the upcast pattern.
RunAndFilecheckHloRewrite(hlo, std::move(pass), std::nullopt);
Expand All @@ -695,7 +695,7 @@ TEST_F(CutlassFusionTest,
patterns.Emplace<CutlassGemmWithUpcastPattern>();

auto device = TestGpuDeviceInfo::RTXA6000DeviceInfo();
CustomKernelFusionRewriter pass(&device, &patterns);
CustomKernelFusionRewriter pass(&device, /*kernel_index=*/0, &patterns);
// Check that hlo is not rewritten after the pass, indicating that we don't
// match the upcast pattern.
RunAndFilecheckHloRewrite(hlo, std::move(pass), std::nullopt);
Expand All @@ -716,7 +716,7 @@ TEST_F(CutlassFusionTest,
patterns.Emplace<CutlassGemmWithUpcastPattern>();

auto device = TestGpuDeviceInfo::RTXA6000DeviceInfo();
CustomKernelFusionRewriter pass(&device, &patterns);
CustomKernelFusionRewriter pass(&device, /*kernel_index=*/0, &patterns);
// Check that hlo is not rewritten after the pass, indicating that we don't
// match the upcast pattern.
RunAndFilecheckHloRewrite(hlo, std::move(pass), std::nullopt);
Expand All @@ -736,7 +736,7 @@ TEST_F(CutlassFusionTest,
CustomKernelFusionPatternRegistry patterns;
patterns.Emplace<CutlassGemmWithUpcastPattern>();
auto device = TestGpuDeviceInfo::RTXA6000DeviceInfo();
CustomKernelFusionRewriter pass(&device, &patterns);
CustomKernelFusionRewriter pass(&device, /*kernel_index=*/0, &patterns);
// Check that hlo is not rewritten after the pass, indicating that we don't
// match the upcast pattern.
RunAndFilecheckHloRewrite(hlo, std::move(pass), std::nullopt);
Expand All @@ -756,7 +756,7 @@ TEST_F(CutlassFusionTest,
CustomKernelFusionPatternRegistry patterns;
patterns.Emplace<CutlassGemmWithUpcastPattern>();
auto device = TestGpuDeviceInfo::RTXA6000DeviceInfo();
CustomKernelFusionRewriter pass(&device, &patterns);
CustomKernelFusionRewriter pass(&device, /*kernel_index=*/0, &patterns);
// Check that hlo is not rewritten after the pass, indicating that we don't
// match the upcast pattern.
RunAndFilecheckHloRewrite(hlo, std::move(pass), std::nullopt);
Expand All @@ -776,7 +776,7 @@ TEST_F(CutlassFusionTest, GemmWithUpcastWithBatchDimensionShouldNotBeFused) {
CustomKernelFusionPatternRegistry patterns;
patterns.Emplace<CutlassGemmWithUpcastPattern>();
auto device = TestGpuDeviceInfo::RTXA6000DeviceInfo();
CustomKernelFusionRewriter pass(&device, &patterns);
CustomKernelFusionRewriter pass(&device, /*kernel_index=*/0, &patterns);
// Check that hlo is not rewritten after the pass, indicating that we don't
// match the upcast pattern.
RunAndFilecheckHloRewrite(hlo, std::move(pass), std::nullopt);
Expand All @@ -795,7 +795,7 @@ TEST_F(CutlassFusionTest, GemmWithUpcastAndColumnMajorOperandsShouldBeFused) {
CustomKernelFusionPatternRegistry patterns;
patterns.Emplace<CutlassGemmWithUpcastPattern>();
auto device = TestGpuDeviceInfo::RTXA6000DeviceInfo();
CustomKernelFusionRewriter pass(&device, &patterns);
CustomKernelFusionRewriter pass(&device, /*kernel_index=*/0, &patterns);
std::string expected = "CHECK: cutlass_gemm";
RunAndFilecheckHloRewrite(hlo, std::move(pass), expected);
EXPECT_TRUE(RunAndCompare(hlo, ErrorSpec{1e-3, 1e-3}));
Expand Down
19 changes: 11 additions & 8 deletions xla/service/gpu/transforms/custom_kernel_fusion_rewriter.cc
Original file line number Diff line number Diff line change
Expand Up @@ -42,9 +42,9 @@ limitations under the License.
namespace xla::gpu {

CustomKernelFusionRewriter::CustomKernelFusionRewriter(
const se::DeviceDescription* device,
const se::DeviceDescription* device, int kernel_index,
const CustomKernelFusionPatternRegistry* patterns)
: device_(device), patterns_(patterns) {}
: device_(device), kernel_index_(kernel_index), patterns_(patterns) {}

// Returns a set of instruction that have users outside of a matched pattern
// and have a replacement that must be applied after building a new custom
Expand Down Expand Up @@ -148,9 +148,11 @@ static absl::StatusOr<HloComputation*> CreateFusionBody(
return module->AddComputationAndUnifyNamesAndIds(builder.Build(), false);
}

static absl::StatusOr<HloInstruction*> CreateFusionInstruction(
namespace {
absl::StatusOr<HloInstruction*> CreateFusionInstruction(
HloModule* module, const CustomKernelFusionPattern::Match& match,
absl::Span<HloInstruction* const> captures, HloComputation* body) {
absl::Span<HloInstruction* const> captures, HloComputation* body,
int kernel_index) {
// We'll be replacing the root operation of a custom kernel fusion with a
// fusion instruction calling fusion computation.
HloInstruction* root = match.root();
Expand All @@ -168,7 +170,7 @@ static absl::StatusOr<HloInstruction*> CreateFusionInstruction(
*gpu_config.mutable_fusion_backend_config();
backend_config.set_kind("__custom_fusion");
*backend_config.mutable_custom_fusion_config() = match.config();
backend_config.mutable_custom_fusion_config()->set_kernel_index(0);
backend_config.mutable_custom_fusion_config()->set_kernel_index(kernel_index);
TF_RETURN_IF_ERROR(fusion->set_backend_config(std::move(gpu_config)));

// If we don't have workspace we can return constructed fusion instruction.
Expand All @@ -178,6 +180,7 @@ static absl::StatusOr<HloInstruction*> CreateFusionInstruction(
return parent->AddInstruction(
HloInstruction::CreateGetTupleElement(fusion, 0));
}
} // namespace

absl::StatusOr<bool> CustomKernelFusionRewriter::Run(
HloModule* module,
Expand Down Expand Up @@ -205,9 +208,9 @@ absl::StatusOr<bool> CustomKernelFusionRewriter::Run(

TF_ASSIGN_OR_RETURN(HloComputation * fusion_body,
CreateFusionBody(module, match, captures));
TF_ASSIGN_OR_RETURN(
HloInstruction * fusion,
CreateFusionInstruction(module, match, captures, fusion_body));
TF_ASSIGN_OR_RETURN(HloInstruction * fusion,
CreateFusionInstruction(module, match, captures,
fusion_body, kernel_index_));

VLOG(2) << "Added a fusion instruction: " << fusion->name()
<< " for custom kernel fusion " << match.config().name()
Expand Down
3 changes: 2 additions & 1 deletion xla/service/gpu/transforms/custom_kernel_fusion_rewriter.h
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ namespace xla::gpu {
class CustomKernelFusionRewriter : public HloModulePass {
public:
explicit CustomKernelFusionRewriter(
const se::DeviceDescription* device,
const se::DeviceDescription* device, int kernel_index = 0,
const CustomKernelFusionPatternRegistry* patterns =
CustomKernelFusionPatternRegistry::Default());

Expand All @@ -78,6 +78,7 @@ class CustomKernelFusionRewriter : public HloModulePass {

private:
const se::DeviceDescription* device_;
const int kernel_index_;
const CustomKernelFusionPatternRegistry* patterns_;
};

Expand Down
24 changes: 22 additions & 2 deletions xla/service/gpu/transforms/custom_kernel_fusion_rewriter_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -88,10 +88,30 @@ TEST_F(CustomKernelFusionRewriterTest, SimpleGemm) {
patterns.Emplace<SimpleGemmPattern>();

auto device = TestGpuDeviceInfo::RTXA6000DeviceInfo();
CustomKernelFusionRewriter pass(&device, &patterns);
CustomKernelFusionRewriter pass(&device, /*kernel_index=*/0, &patterns);
RunAndFilecheckHloRewrite(hlo, std::move(pass), expected);
}

TEST_F(CustomKernelFusionRewriterTest, SetsKernelIndex) {
const char* hlo = R"(
HloModule test
ENTRY %main (p0: f16[15,19], p1: f16[19,17]) -> f16[15,17] {
%p0 = f16[15,19]{1,0} parameter(0)
%p1 = f16[19,17]{1,0} parameter(1)
ROOT %r = f16[15,17]{1,0} dot(%p0, %p1),
lhs_contracting_dims={1}, rhs_contracting_dims={0}
}
)";

CustomKernelFusionPatternRegistry patterns;
patterns.Emplace<SimpleGemmPattern>();

auto device = TestGpuDeviceInfo::RTXA6000DeviceInfo();
CustomKernelFusionRewriter pass(&device, /*kernel_index=*/1, &patterns);
RunAndFilecheckHloRewrite(hlo, std::move(pass), "CHECK: \"kernel_index\":1");
}

TEST_F(CustomKernelFusionRewriterTest, SimpleGemmWithWorkspace) {
const char* hlo = R"(
HloModule test
Expand Down Expand Up @@ -131,7 +151,7 @@ TEST_F(CustomKernelFusionRewriterTest, SimpleGemmWithWorkspace) {
patterns.Emplace<SimpleGemmPattern>(1024);

auto device = TestGpuDeviceInfo::RTXA6000DeviceInfo();
CustomKernelFusionRewriter pass(&device, &patterns);
CustomKernelFusionRewriter pass(&device, /*kernel_index=*/0, &patterns);
RunAndFilecheckHloRewrite(hlo, std::move(pass), expected);
}

Expand Down

0 comments on commit 45dea71

Please sign in to comment.