diff --git a/xla/service/gpu/model/BUILD b/xla/service/gpu/model/BUILD index ef948e8a43115..5825878005729 100644 --- a/xla/service/gpu/model/BUILD +++ b/xla/service/gpu/model/BUILD @@ -402,6 +402,7 @@ xla_cc_test( "//xla/stream_executor:device_description", "//xla/tests:hlo_test_base", "//xla/tests:xla_internal_test_main", + "//xla/tsl/lib/core:status_test_util", "@com_google_absl//absl/status", "@com_google_absl//absl/strings", "@com_google_absl//absl/time", diff --git a/xla/service/gpu/model/gpu_indexing_performance_model.cc b/xla/service/gpu/model/gpu_indexing_performance_model.cc index cb9b974b3af63..4ec9f347dff90 100644 --- a/xla/service/gpu/model/gpu_indexing_performance_model.cc +++ b/xla/service/gpu/model/gpu_indexing_performance_model.cc @@ -367,14 +367,15 @@ GpuPerformanceModelWithIndexingAnalysis::EstimateRunTimeForTiledHloComputation( const HloInstruction* hlo = tiled_hlo->hlo(); - if (hlo->opcode() == HloOpcode::kConcatenate) { - // TODO(b/351342921): Add propagation of the number of blocks that read or - // compute a tile. Concatenate is the only operation that may change that. - return absl::FailedPreconditionError( - "Concatenate is not supported by the indexing cost model."); - } - if (fusion_adaptor.ContainsInstruction(hlo)) { + if (hlo->opcode() == HloOpcode::kConcatenate) { + // TODO(b/351342921): Add propagation of the number of blocks that read + // or compute a tile. Concatenate is the only operation that may change + // that. + return absl::FailedPreconditionError( + "Concatenate is not supported by the indexing cost model."); + } + // Total number of elements computed for this tile across all blocks. // // Even if real `tile_size` is smaller than `padded_tile_size`, SM will diff --git a/xla/service/gpu/model/gpu_indexing_performance_model_test.cc b/xla/service/gpu/model/gpu_indexing_performance_model_test.cc index d9a2ab7b1e2bf..5b9f6d4939b62 100644 --- a/xla/service/gpu/model/gpu_indexing_performance_model_test.cc +++ b/xla/service/gpu/model/gpu_indexing_performance_model_test.cc @@ -42,6 +42,7 @@ limitations under the License. #include "xla/stream_executor/device_description.h" #include "xla/test_helpers.h" #include "xla/tests/hlo_test_base.h" +#include "xla/tsl/lib/core/status_test_util.h" #include "tsl/platform/status_matchers.h" #include "tsl/platform/statusor.h" @@ -385,6 +386,38 @@ ENTRY main { EXPECT_NEAR(absl::ToDoubleSeconds(runtime_data.exec_time), 185, 1); } +// TODO(b/351342921): Remove this test once there is no special filter for +// concatenate in Cost Model. +TEST_F(GpuIndexingPerformanceModelTest, + EstimateRunTimeForTiledFusion_ConcatenateOperandIsSupported) { + TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(R"( +HloModule m + +fusion { + param_0 = f32[32,64] parameter(0) + param_1 = f32[32,64] parameter(1) + ROOT subtract = f32[32,64] subtract(param_0, param_1) +} + +ENTRY main { + param_0 = f32[32,16] parameter(0) + param_1 = f32[32,48] parameter(1) + param_2 = f32[32,64] parameter(2) + concatenate = f32[32,64] concatenate(param_0, param_1), dimensions={1} + ROOT fusion = f32[32,64] fusion(concatenate, param_2), kind=kCustom, calls=fusion +})")); + + auto fusion_adaptor = HloFusionAdaptor::ForInstruction( + module->entry_computation()->root_instruction()); + + LaunchDimensions launch_dimensions{8, WarpSize()}; + + auto result = indexing_cost_model_.EstimateRunTimeForTiledFusion( + *fusion_adaptor, launch_dimensions, /*output_tile_sizes=*/{16, 16}); + + TF_EXPECT_OK(result.status()); +} + TEST_F(GpuIndexingPerformanceModelTest, EstimateRunTimeForTiledFusion_ConcatenateIsNotSupported) { TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(R"(