diff --git a/xla/service/gpu/transforms/gemm_fusion.cc b/xla/service/gpu/transforms/gemm_fusion.cc index 9d0c347f82ed8..656a338b19f06 100644 --- a/xla/service/gpu/transforms/gemm_fusion.cc +++ b/xla/service/gpu/transforms/gemm_fusion.cc @@ -198,13 +198,19 @@ std::optional GetOperandDimOrdersAndCombinedReqs( DimOrdersAndReqsOrError dim_orders_and_new_reqs = GetPropagatedDimOrdersAndRequirements( hlo, dim_order, TransformDirection::kOutputToInput, properties); - if (!std::holds_alternative(dim_orders_and_new_reqs)) { + if (std::holds_alternative(dim_orders_and_new_reqs)) { + VLOG(5) << "Not fusing " << hlo.ToString() + << " to the output due to the decision: " + << std::get(dim_orders_and_new_reqs).Explain(); return std::nullopt; } DotRequirementsOrError combined_reqs = CombineDotRequirements( requirements, std::get(dim_orders_and_new_reqs).requirements); - if (!std::holds_alternative(combined_reqs)) { + if (std::holds_alternative(combined_reqs)) { + VLOG(5) << "Not fusing " << hlo.ToString() + << " to the output due to the decision: " + << std::get(combined_reqs).Explain(); return std::nullopt; } return DimOrdersAndReqs{ @@ -223,13 +229,19 @@ std::optional GetOperandDimOrdersAndCombinedReqsIfProfitable( hlo, TransformDirection::kOutputToInput, /*src_operand_index=*/std::nullopt, dim_order, gpu_version, properties); - if (!std::holds_alternative(dim_orders_and_new_reqs)) { + if (std::holds_alternative(dim_orders_and_new_reqs)) { + VLOG(5) << "Not fusing " << hlo.ToString() + << " to the output due to the decision: " + << std::get(dim_orders_and_new_reqs).Explain(); return std::nullopt; } DotRequirementsOrError combined_reqs = CombineDotRequirements( requirements, std::get(dim_orders_and_new_reqs).requirements); - if (!std::holds_alternative(combined_reqs)) { + if (std::holds_alternative(combined_reqs)) { + VLOG(5) << "Not fusing " << hlo.ToString() + << " to the output due to the decision: " + << std::get(combined_reqs).Explain(); return std::nullopt; } return DimOrdersAndReqs{ @@ -247,13 +259,19 @@ std::optional GetUserDimOrdersAndCombinedReqsIfProfitable( GetPropagatedDimOrdersAndRequirementsIfProfitablyFusible( user, TransformDirection::kInputToOutput, user.operand_index(&hlo), hlo_dim_order, gpu_version, properties); - if (!std::holds_alternative(dim_orders_and_new_reqs)) { + if (std::holds_alternative(dim_orders_and_new_reqs)) { + VLOG(5) << "Not fusing " << user.ToString() + << " to the input due to the decision: " + << std::get(dim_orders_and_new_reqs).Explain(); return std::nullopt; } DotRequirementsOrError combined_reqs = CombineDotRequirements( requirements, std::get(dim_orders_and_new_reqs).requirements); - if (!std::holds_alternative(combined_reqs)) { + if (std::holds_alternative(combined_reqs)) { + VLOG(5) << "Not fusing " << user.ToString() + << " to the input due to the decision: " + << std::get(combined_reqs).Explain(); return std::nullopt; } return DimOrdersAndReqs{ diff --git a/xla/service/gpu/triton_tiling_propagation.cc b/xla/service/gpu/triton_tiling_propagation.cc index bf27eb9dd8a98..1943af5343b42 100644 --- a/xla/service/gpu/triton_tiling_propagation.cc +++ b/xla/service/gpu/triton_tiling_propagation.cc @@ -1023,7 +1023,10 @@ GetPropagatedDimOrdersAndRequirementsIfProfitablyFusible( DimOrdersAndReqsOrError result_or_error = GetPropagatedDimOrdersAndRequirements(hlo, src_dim_order, transform_direction, properties); - if (!std::holds_alternative(result_or_error)) { + if (std::holds_alternative(result_or_error)) { + VLOG(5) << "Not fusing " << hlo.ToString() + << " to the output due to the decision: " + << std::get(result_or_error).Explain(); return result_or_error; } DimOrdersAndReqs dim_orders_and_requirements =