[XLA] Propagate the layout of layout constrained custom calls with higher
priority because they have no ability to accept another layout.

PiperOrigin-RevId: 674755625
blakehechtman authored and Google-ML-Automation committed Sep 15, 2024
1 parent dedab4f commit af733ec
Showing 2 changed files with 63 additions and 21 deletions.
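Both hunks in layout_assignment.cc key on the IsLayoutConstrainedCustomCall predicate, which this diff does not show. As background, a custom call is layout constrained when it was created with explicit operand layouts (the operand_layout_constraints attribute in HLO text), so layout assignment may not change its operand or result layouts. A minimal sketch of that check, under the assumption that it simply consults HloCustomCallInstruction::layout_constrained() (include paths are assumptions, not taken from this diff):

#include "xla/hlo/ir/hlo_casting_utils.h"
#include "xla/hlo/ir/hlo_instruction.h"
#include "xla/hlo/ir/hlo_instructions.h"

// Sketch only, not the helper from layout_assignment.cc: a custom call is
// "layout constrained" when it carries explicit operand shapes-with-layout,
// so its operand and result layouts must be treated as fixed.
bool IsLayoutConstrainedCustomCall(xla::HloInstruction* instruction) {
  const auto* custom_call =
      xla::DynCast<xla::HloCustomCallInstruction>(instruction);
  return custom_call != nullptr && custom_call->layout_constrained();
}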
48 changes: 27 additions & 21 deletions xla/service/layout_assignment.cc
@@ -741,27 +741,6 @@ absl::Status LayoutAssignment::AddMandatoryConstraints(
           }
         }
       }
-    } else if (IsLayoutConstrainedCustomCall(instruction)) {
-      const HloCustomCallInstruction* custom_call =
-          DynCast<HloCustomCallInstruction>(instruction);
-
-      TF_RETURN_IF_ERROR(SetInstructionLayout(custom_call->shape(), custom_call,
-                                              /*mandatory=*/true, /*dfs=*/true,
-                                              /*allow_alias=*/true));
-      if (custom_call->IsCustomCall("LayoutConstraint")) {
-        TF_RETURN_IF_ERROR(
-            SetOperandLayout(custom_call->shape(), custom_call, 0));
-      } else {
-        for (int64_t i = 0; i < custom_call->operand_count(); ++i) {
-          if (AnyOperandBufferForwarded(custom_call, i)) {
-            TF_RET_CHECK(AllOperandBuffersForwarded(custom_call, i))
-                << "Partial alias of an operand is not supported";
-          } else {
-            TF_RETURN_IF_ERROR(SetOperandLayout(
-                custom_call->operand_shapes_with_layout()[i], custom_call, i));
-          }
-        }
-      }
     } else if (IsLayoutConstrainedCollective(instruction)) {
       TF_RETURN_IF_ERROR(
           SetInstructionLayout(instruction->shape(), instruction));
@@ -2476,6 +2455,33 @@ absl::Status LayoutAssignment::RunOnComputation(
   // Add any backend-specific constraints.
   TF_RETURN_IF_ERROR(AddBackendConstraints(constraints));
 
+  for (HloInstruction* instruction :
+       constraints->computation()->MakeInstructionPostOrder()) {
+    if (!IsLayoutConstrainedCustomCall(instruction)) {
+      continue;
+    }
+    const HloCustomCallInstruction* custom_call =
+        DynCast<HloCustomCallInstruction>(instruction);
+
+    TF_RETURN_IF_ERROR(SetInstructionLayout(custom_call->shape(), custom_call,
+                                            /*mandatory=*/true, /*dfs=*/true,
+                                            /*allow_alias=*/true));
+    if (custom_call->IsCustomCall("LayoutConstraint")) {
+      TF_RETURN_IF_ERROR(
+          SetOperandLayout(custom_call->shape(), custom_call, 0));
+    } else {
+      for (int64_t i = 0; i < custom_call->operand_count(); ++i) {
+        if (AnyOperandBufferForwarded(custom_call, i)) {
+          TF_RET_CHECK(AllOperandBuffersForwarded(custom_call, i))
+              << "Partial alias of an operand is not supported";
+        } else {
+          TF_RETURN_IF_ERROR(SetOperandLayout(
+              custom_call->operand_shapes_with_layout()[i], custom_call, i));
+        }
+      }
+    }
+  }
+
   // Propagates layouts from mandatory and backend constraints.
   TF_RETURN_IF_ERROR(PropagateConstraints(constraints));
 
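The block removed from AddMandatoryConstraints now runs here, after AddBackendConstraints and immediately before PropagateConstraints, so, as the commit message says, the fixed layouts of layout constrained custom calls are propagated with higher priority; these instructions cannot accept any other layout. For illustration, the sketch below shows one way such an instruction is built programmatically: passing operand_shapes_with_layout to HloInstruction::CreateCustomCall pins the operand layouts, and the explicit layout on the result shape pins the output. The helper name MakeConstrainedBaz and the "baz" target are hypothetical; the shapes mirror the new test in layout_assignment_test.cc below.

// Hypothetical helper, sketched for illustration only: build a
// layout-constrained custom call over two existing instructions.
std::unique_ptr<xla::HloInstruction> MakeConstrainedBaz(
    xla::HloInstruction* p0, xla::HloInstruction* p1) {
  using xla::Shape;
  using xla::ShapeUtil;
  // Result layout {3,2,0,1} and operand layouts {0,1} and {1,0}, as in the
  // test's operand_layout_constraints.
  Shape result = ShapeUtil::MakeShapeWithDenseLayout(
      xla::F32, /*dimensions=*/{1, 2, 3, 4}, /*minor_to_major=*/{3, 2, 0, 1});
  Shape operand0 =
      ShapeUtil::MakeShapeWithDenseLayout(xla::F32, {4, 4}, {0, 1});
  Shape operand1 =
      ShapeUtil::MakeShapeWithDenseLayout(xla::F32, {2, 3}, {1, 0});
  // The operand_shapes_with_layout argument is what marks the custom call as
  // layout constrained.
  return xla::HloInstruction::CreateCustomCall(
      result, /*operands=*/{p0, p1}, /*custom_call_target=*/"baz",
      /*operand_shapes_with_layout=*/{operand0, operand1});
}

A call created this way reports layout_constrained() == true, and its operand_shapes_with_layout() returns the two pinned shapes, which is exactly what the loop above reads back.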
36 changes: 36 additions & 0 deletions xla/service/layout_assignment_test.cc
@@ -1250,6 +1250,42 @@ ENTRY %CustomCallWithLayoutConstraints (p0: f32[4,4], p1: f32[2,3]) -> f32[1,2,3
   ExpectLayoutIs(custom_call->operand(1)->shape(), {1, 0});
 }
 
+TEST_F(LayoutAssignmentTest, CustomCallLayoutConstrainedAndElementwise) {
+  const char* module_str = R"(
+HloModule CustomCallLayoutConstrained
+ENTRY %CustomCallWithLayoutConstraints (p0: f32[4,4], p1: f32[2,3]) -> f32[1,2,3,4] {
+  p0 = f32[4,4] parameter(0)
+  p1 = f32[2,3] parameter(1)
+  cc = f32[1,2,3,4]{3,2,0,1} custom-call(f32[4,4] %p0, f32[2,3] %p1), custom_call_target="baz", operand_layout_constraints={f32[4,4]{0,1}, f32[2,3]{1,0}}
+  ROOT e = f32[1,2,3,4] exponential(cc)
+}
+)";
+  TF_ASSERT_OK_AND_ASSIGN(
+      std::unique_ptr<VerifiedHloModule> m,
+      ParseAndReturnVerifiedModule(module_str, GetModuleConfigForTest()));
+  ComputationLayout computation_layout = m->entry_computation_layout();
+  *computation_layout.mutable_parameter_layout(0) =
+      ShapeLayout(ShapeUtil::MakeShapeWithDenseLayout(F32, {4, 4}, {1, 0}));
+  *computation_layout.mutable_parameter_layout(1) =
+      ShapeLayout(ShapeUtil::MakeShapeWithDenseLayout(F32, {2, 3}, {1, 0}));
+  *computation_layout.mutable_result_layout() = ShapeLayout(
+      ShapeUtil::MakeShapeWithDenseLayout(F32, {1, 2, 3, 4}, {2, 1, 0, 3}));
+  AssignLayouts(m.get(), &computation_layout);
+
+  // The custom call should be partially encapsulated in kCopy instructions
+  // because of the layout mismatches.
+  ASSERT_THAT(
+      m->entry_computation()->root_instruction(),
+      GmockMatch(m::Copy(m::Exp(m::CustomCall(m::Copy(), m::Parameter())))));
+
+  const HloInstruction* custom_call =
+      m->entry_computation()->root_instruction()->operand(0)->operand(0);
+  ExpectLayoutIs(custom_call->shape(), {3, 2, 0, 1});
+  ExpectLayoutIs(custom_call->operand(0)->shape(), {0, 1});
+  ExpectLayoutIs(custom_call->operand(1)->shape(), {1, 0});
+}
+
 TEST_F(LayoutAssignmentTest, CustomCallLayoutConstrainedAliasedOutput) {
   const char* module_str = R"(
 HloModule customcall.4
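The layout checks in these tests go through the ExpectLayoutIs fixture helper, which the diff does not show. A plausible sketch of what it verifies, namely that a shape's minor-to-major order equals the given permutation (the actual helper in layout_assignment_test.cc may differ in details):

// Sketch of the ExpectLayoutIs helper used in the tests above; assumed shape,
// not copied from the real fixture.
void ExpectLayoutIs(const xla::Shape& shape,
                    absl::Span<const int64_t> minor_to_major) {
  const xla::Layout expected = xla::LayoutUtil::MakeLayout(minor_to_major);
  EXPECT_TRUE(xla::LayoutUtil::Equal(expected, shape.layout()))
      << "Expected layout " << expected.ToString() << ", actual "
      << shape.layout().ToString();
}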
