Skip to content

Commit

Permalink
[XLA:GPU] Add some logging with the fusing decisions.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 675164327
  • Loading branch information
loislo authored and Google-ML-Automation committed Sep 16, 2024
1 parent efab41b commit dc5d8b0
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 7 deletions.
30 changes: 24 additions & 6 deletions xla/service/gpu/transforms/gemm_fusion.cc
Original file line number Diff line number Diff line change
Expand Up @@ -198,13 +198,19 @@ std::optional<DimOrdersAndReqs> GetOperandDimOrdersAndCombinedReqs(
DimOrdersAndReqsOrError dim_orders_and_new_reqs =
GetPropagatedDimOrdersAndRequirements(
hlo, dim_order, TransformDirection::kOutputToInput, properties);
if (!std::holds_alternative<DimOrdersAndReqs>(dim_orders_and_new_reqs)) {
if (std::holds_alternative<FusionDecision>(dim_orders_and_new_reqs)) {
VLOG(5) << "Not fusing " << hlo.ToString()
<< " to the output due to the decision: "
<< std::get<FusionDecision>(dim_orders_and_new_reqs).Explain();
return std::nullopt;
}
DotRequirementsOrError combined_reqs = CombineDotRequirements(
requirements,
std::get<DimOrdersAndReqs>(dim_orders_and_new_reqs).requirements);
if (!std::holds_alternative<DotRequirements>(combined_reqs)) {
if (std::holds_alternative<FusionDecision>(combined_reqs)) {
VLOG(5) << "Not fusing " << hlo.ToString()
<< " to the output due to the decision: "
<< std::get<FusionDecision>(combined_reqs).Explain();
return std::nullopt;
}
return DimOrdersAndReqs{
Expand All @@ -223,13 +229,19 @@ std::optional<DimOrdersAndReqs> GetOperandDimOrdersAndCombinedReqsIfProfitable(
hlo, TransformDirection::kOutputToInput,
/*src_operand_index=*/std::nullopt, dim_order, gpu_version,
properties);
if (!std::holds_alternative<DimOrdersAndReqs>(dim_orders_and_new_reqs)) {
if (std::holds_alternative<FusionDecision>(dim_orders_and_new_reqs)) {
VLOG(5) << "Not fusing " << hlo.ToString()
<< " to the output due to the decision: "
<< std::get<FusionDecision>(dim_orders_and_new_reqs).Explain();
return std::nullopt;
}
DotRequirementsOrError combined_reqs = CombineDotRequirements(
requirements,
std::get<DimOrdersAndReqs>(dim_orders_and_new_reqs).requirements);
if (!std::holds_alternative<DotRequirements>(combined_reqs)) {
if (std::holds_alternative<FusionDecision>(combined_reqs)) {
VLOG(5) << "Not fusing " << hlo.ToString()
<< " to the output due to the decision: "
<< std::get<FusionDecision>(combined_reqs).Explain();
return std::nullopt;
}
return DimOrdersAndReqs{
Expand All @@ -247,13 +259,19 @@ std::optional<DimOrdersAndReqs> GetUserDimOrdersAndCombinedReqsIfProfitable(
GetPropagatedDimOrdersAndRequirementsIfProfitablyFusible(
user, TransformDirection::kInputToOutput, user.operand_index(&hlo),
hlo_dim_order, gpu_version, properties);
if (!std::holds_alternative<DimOrdersAndReqs>(dim_orders_and_new_reqs)) {
if (std::holds_alternative<FusionDecision>(dim_orders_and_new_reqs)) {
VLOG(5) << "Not fusing " << user.ToString()
<< " to the input due to the decision: "
<< std::get<FusionDecision>(dim_orders_and_new_reqs).Explain();
return std::nullopt;
}
DotRequirementsOrError combined_reqs = CombineDotRequirements(
requirements,
std::get<DimOrdersAndReqs>(dim_orders_and_new_reqs).requirements);
if (!std::holds_alternative<DotRequirements>(combined_reqs)) {
if (std::holds_alternative<FusionDecision>(combined_reqs)) {
VLOG(5) << "Not fusing " << user.ToString()
<< " to the input due to the decision: "
<< std::get<FusionDecision>(combined_reqs).Explain();
return std::nullopt;
}
return DimOrdersAndReqs{
Expand Down
5 changes: 4 additions & 1 deletion xla/service/gpu/triton_tiling_propagation.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1023,7 +1023,10 @@ GetPropagatedDimOrdersAndRequirementsIfProfitablyFusible(
DimOrdersAndReqsOrError result_or_error =
GetPropagatedDimOrdersAndRequirements(hlo, src_dim_order,
transform_direction, properties);
if (!std::holds_alternative<DimOrdersAndReqs>(result_or_error)) {
if (std::holds_alternative<FusionDecision>(result_or_error)) {
VLOG(5) << "Not fusing " << hlo.ToString()
<< " to the output due to the decision: "
<< std::get<FusionDecision>(result_or_error).Explain();
return result_or_error;
}
DimOrdersAndReqs dim_orders_and_requirements =
Expand Down

0 comments on commit dc5d8b0

Please sign in to comment.