Skip to content

Commit

Permalink
[XLA:GPU] Fix a bug in Cost Model that doesn't allow concatenate as o…
Browse files Browse the repository at this point in the history
…perands.

PiperOrigin-RevId: 675490471
  • Loading branch information
olegshyshkov authored and Google-ML-Automation committed Sep 17, 2024
1 parent a739fcd commit e51e376
Show file tree
Hide file tree
Showing 3 changed files with 42 additions and 7 deletions.
1 change: 1 addition & 0 deletions xla/service/gpu/model/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -402,6 +402,7 @@ xla_cc_test(
"//xla/stream_executor:device_description",
"//xla/tests:hlo_test_base",
"//xla/tests:xla_internal_test_main",
"//xla/tsl/lib/core:status_test_util",
"@com_google_absl//absl/status",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/time",
Expand Down
15 changes: 8 additions & 7 deletions xla/service/gpu/model/gpu_indexing_performance_model.cc
Original file line number Diff line number Diff line change
Expand Up @@ -367,14 +367,15 @@ GpuPerformanceModelWithIndexingAnalysis::EstimateRunTimeForTiledHloComputation(

const HloInstruction* hlo = tiled_hlo->hlo();

if (hlo->opcode() == HloOpcode::kConcatenate) {
// TODO(b/351342921): Add propagation of the number of blocks that read or
// compute a tile. Concatenate is the only operation that may change that.
return absl::FailedPreconditionError(
"Concatenate is not supported by the indexing cost model.");
}

if (fusion_adaptor.ContainsInstruction(hlo)) {
if (hlo->opcode() == HloOpcode::kConcatenate) {
// TODO(b/351342921): Add propagation of the number of blocks that read
// or compute a tile. Concatenate is the only operation that may change
// that.
return absl::FailedPreconditionError(
"Concatenate is not supported by the indexing cost model.");
}

// Total number of elements computed for this tile across all blocks.
//
// Even if real `tile_size` is smaller than `padded_tile_size`, SM will
Expand Down
33 changes: 33 additions & 0 deletions xla/service/gpu/model/gpu_indexing_performance_model_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ limitations under the License.
#include "xla/stream_executor/device_description.h"
#include "xla/test_helpers.h"
#include "xla/tests/hlo_test_base.h"
#include "xla/tsl/lib/core/status_test_util.h"
#include "tsl/platform/status_matchers.h"
#include "tsl/platform/statusor.h"

Expand Down Expand Up @@ -385,6 +386,38 @@ ENTRY main {
EXPECT_NEAR(absl::ToDoubleSeconds(runtime_data.exec_time), 185, 1);
}

// TODO(b/351342921): Remove this test once there is no special filter for
// concatenate in Cost Model.
TEST_F(GpuIndexingPerformanceModelTest,
EstimateRunTimeForTiledFusion_ConcatenateOperandIsSupported) {
TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(R"(
HloModule m
fusion {
param_0 = f32[32,64] parameter(0)
param_1 = f32[32,64] parameter(1)
ROOT subtract = f32[32,64] subtract(param_0, param_1)
}
ENTRY main {
param_0 = f32[32,16] parameter(0)
param_1 = f32[32,48] parameter(1)
param_2 = f32[32,64] parameter(2)
concatenate = f32[32,64] concatenate(param_0, param_1), dimensions={1}
ROOT fusion = f32[32,64] fusion(concatenate, param_2), kind=kCustom, calls=fusion
})"));

auto fusion_adaptor = HloFusionAdaptor::ForInstruction(
module->entry_computation()->root_instruction());

LaunchDimensions launch_dimensions{8, WarpSize()};

auto result = indexing_cost_model_.EstimateRunTimeForTiledFusion(
*fusion_adaptor, launch_dimensions, /*output_tile_sizes=*/{16, 16});

TF_EXPECT_OK(result.status());
}

TEST_F(GpuIndexingPerformanceModelTest,
EstimateRunTimeForTiledFusion_ConcatenateIsNotSupported) {
TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(R"(
Expand Down

0 comments on commit e51e376

Please sign in to comment.