Layer norm fusion with type conversions of input, scale and bias.
philipphack committed Sep 17, 2024
1 parent 0c440a1 commit 2880fde
Showing 2 changed files with 146 additions and 54 deletions.
137 changes: 83 additions & 54 deletions xla/service/gpu/transforms/cudnn_norm_rewriter.cc
@@ -113,15 +113,19 @@ using NormMetadataMap = absl::flat_hash_map<HloInstruction*, NormMetadata>;
// HloInstruction:
// UniqueHloInstruction x;
// bool m = Match(
// instr, m::Divide(m::Cos(m::Op().WithPredicate(x.capture_and_verify)),
// m::Sin(m::Op().WithPredicate(x.capture_and_verify))));
// instr, m::Divide(m::Cos(m::Op().WithPredicate(x.CaptureOrVerifyFn())),
// m::Sin(m::Op().WithPredicate(x.CaptureOrVerifyFn()))));
// m is true and x.Instr() returns an HloInstruction pointer to the operand of
// cosine and sine iff HloInstruction *instr points to a division of a cosine by
// a sine that operate on the same instruction.
class UniqueHloInstruction {
public:
UniqueHloInstruction()
: is_set_(false), instr_(nullptr), capture_or_verify_() {}
: is_set_(false),
instr_(nullptr),
capture_or_verify_([this](const HloInstruction* instr) -> bool {
return CaptureOrVerify(const_cast<HloInstruction*>(instr));
}) {}
HloInstruction* Instr() const { return instr_; }
void SetInstr(HloInstruction* instr) {
is_set_ = true;
@@ -143,12 +147,7 @@ class UniqueHloInstruction {

// Returns a std::function for capturing or verifying an instruction using
// WithPredicate.
std::function<bool(const HloInstruction*)> GetCaptureOrVerifyFn() {
if (!capture_or_verify_) {
capture_or_verify_ = [this](const HloInstruction* instr) -> bool {
return CaptureOrVerify(const_cast<HloInstruction*>(instr));
};
}
std::function<bool(const HloInstruction*)> CaptureOrVerifyFn() const {
return capture_or_verify_;
}

@@ -465,6 +464,16 @@ auto OptionalSupportedTransform(Pattern pattern) {
SupportedBitcastOrReshape(shared_subpattern), shared_subpattern);
}

// Broadcast with optional supported type conversion.
template <typename Pattern>
auto Broadcast(HloInstruction** bcast, Pattern pattern) {
auto shared_subpattern = m::SharedSubpattern(pattern);
return m::AnyOf<HloInstruction>(
      SupportedConvert(m::Broadcast(bcast, shared_subpattern)),
      m::Broadcast(bcast, SupportedConvert(shared_subpattern)),
      m::Broadcast(bcast, shared_subpattern));
}
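
For illustration, a minimal usage sketch of this helper (the variable names here are hypothetical; an equivalent call appears in the visitor's layer norm pattern further down). It accepts a broadcast of the scale or bias regardless of whether a supported convert sits above or below the broadcast:

  HloInstruction* broadcast_scale = nullptr;
  HloInstruction* scale = nullptr;
  // Matches convert(broadcast(scale)), broadcast(convert(scale)), or a plain
  // broadcast(scale); binds the broadcast to broadcast_scale and the
  // pre-conversion operand to scale.
  bool matched = Match(instr, Broadcast(&broadcast_scale, m::Op(&scale)));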

// Bitcast or reshape with optional supported type conversion and/or addition or
// removal of degenerate dimensions.
template <typename Pattern>
@@ -597,7 +606,7 @@ auto Expectation(UniqueHloInstruction* expectation, Pattern pattern) {
.WithPredicate([](const HloInstruction* instr) {
return CalculatesExpectation(instr);
})
.WithPredicate(expectation->GetCaptureOrVerifyFn()));
.WithPredicate(expectation->CaptureOrVerifyFn()));
return m::AnyOf<HloInstruction>(m::Broadcast(shared_subpattern),
shared_subpattern);
}
@@ -612,7 +621,7 @@ auto Expectation(UniqueHloInstruction* expectation, HloInstruction** reduce,
.WithPredicate([](const HloInstruction* instr) {
return CalculatesExpectation(instr);
})
.WithPredicate(expectation->GetCaptureOrVerifyFn()));
.WithPredicate(expectation->CaptureOrVerifyFn()));
return m::AnyOf<HloInstruction>(m::Broadcast(shared_subpattern),
shared_subpattern);
}
@@ -624,19 +633,19 @@ auto Variance(UniqueHloInstruction* variance, UniqueHloInstruction* expectation,
return m::AnyOf<HloInstruction>(
Subtract(
Expectation(Square(OptionalSupportedTransform(
m::Op().WithPredicate(x->GetCaptureOrVerifyFn())))),
Square(Expectation(expectation,
OptionalSupportedTransform(m::Op().WithPredicate(
x->GetCaptureOrVerifyFn())))))
.WithPredicate(variance->GetCaptureOrVerifyFn()),
m::Op().WithPredicate(x->CaptureOrVerifyFn())))),
Square(Expectation(
expectation, OptionalSupportedTransform(
m::Op().WithPredicate(x->CaptureOrVerifyFn())))))
.WithPredicate(variance->CaptureOrVerifyFn()),
Expectation(
Square(Subtract(
OptionalSupportedTransform(
m::Op().WithPredicate(x->GetCaptureOrVerifyFn())),
m::Op().WithPredicate(x->CaptureOrVerifyFn())),
Expectation(expectation,
OptionalSupportedTransform(m::Op().WithPredicate(
x->GetCaptureOrVerifyFn()))))))
.WithPredicate(variance->GetCaptureOrVerifyFn()));
OptionalSupportedTransform(
m::Op().WithPredicate(x->CaptureOrVerifyFn()))))))
.WithPredicate(variance->CaptureOrVerifyFn()));
}
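
For reference, the pattern above accepts the two standard, algebraically equivalent forms of the variance; together with the norm factor and affine transformation matched by the helpers below, the recognized computation is ordinary layer normalization (notation added here for clarity, with \gamma and \beta denoting the scale and bias parameters):

  \mathrm{Var}[x] = \mathbb{E}[x^2] - \mathbb{E}[x]^2 = \mathbb{E}[(x - \mathbb{E}[x])^2]
  y = (x - \mathbb{E}[x]) \cdot \frac{1}{\sqrt{\mathrm{Var}[x] + \epsilon}} \cdot \gamma + \beta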

// Reciprocal of the square root of variance + epsilon with optional broadcast.
@@ -647,7 +656,7 @@ auto NormFactor(HloInstruction** norm_factor, UniqueHloInstruction* x,
auto shared_subpattern = m::SharedSubpattern(Rsqrt(
norm_factor, AddAnyOrder(Variance(variance, expectation, x),
m::Broadcast(m::ConstantScalar().WithPredicate(
epsilon->GetCaptureOrVerifyFn())))));
epsilon->CaptureOrVerifyFn())))));
return m::AnyOf<HloInstruction>(m::Broadcast(shared_subpattern),
shared_subpattern);
}
@@ -696,10 +705,10 @@ auto SubtractMultiplyAddAnyOrder(P0 p0, P1 p1, P2 p2, P3 p3, P4 p4) {

// Expectation fused into a layer norm Custom Call.
auto FusedExpectation(UniqueHloInstruction* custom_call) {
auto shared_subpattern = m::SharedSubpattern(m::GetTupleElement(
m::CustomCall({kCudnnNormCallTarget})
.WithPredicate(custom_call->GetCaptureOrVerifyFn()),
1));
auto shared_subpattern = m::SharedSubpattern(
m::GetTupleElement(m::CustomCall({kCudnnNormCallTarget})
.WithPredicate(custom_call->CaptureOrVerifyFn()),
1));
return m::AnyOf<HloInstruction>(shared_subpattern,
BitcastOrReshape(shared_subpattern));
}
@@ -708,21 +717,20 @@ auto FusedExpectation(UniqueHloInstruction* custom_call) {
auto FusedExpectation(UniqueHloInstruction* fused_expectation,
UniqueHloInstruction* custom_call) {
auto shared_subpattern = m::SharedSubpattern(
m::GetTupleElement(
m::CustomCall({kCudnnNormCallTarget})
.WithPredicate(custom_call->GetCaptureOrVerifyFn()),
1)
.WithPredicate(fused_expectation->GetCaptureOrVerifyFn()));
m::GetTupleElement(m::CustomCall({kCudnnNormCallTarget})
.WithPredicate(custom_call->CaptureOrVerifyFn()),
1)
.WithPredicate(fused_expectation->CaptureOrVerifyFn()));
return m::AnyOf<HloInstruction>(shared_subpattern,
BitcastOrReshape(shared_subpattern));
}

// Norm factor fused into a layer norm Custom Call.
auto FusedNormFactor(UniqueHloInstruction* custom_call) {
auto shared_subpattern = m::SharedSubpattern(m::GetTupleElement(
m::CustomCall({kCudnnNormCallTarget})
.WithPredicate(custom_call->GetCaptureOrVerifyFn()),
2));
auto shared_subpattern = m::SharedSubpattern(
m::GetTupleElement(m::CustomCall({kCudnnNormCallTarget})
.WithPredicate(custom_call->CaptureOrVerifyFn()),
2));
return m::AnyOf<HloInstruction>(shared_subpattern,
BitcastOrReshape(shared_subpattern));
}
@@ -731,11 +739,10 @@ auto FusedNormFactor(UniqueHloInstruction* custom_call) {
auto FusedNormFactor(UniqueHloInstruction* fused_norm_factor,
UniqueHloInstruction* custom_call) {
auto shared_subpattern = m::SharedSubpattern(
m::GetTupleElement(
m::CustomCall({kCudnnNormCallTarget})
.WithPredicate(custom_call->GetCaptureOrVerifyFn()),
2)
.WithPredicate(fused_norm_factor->GetCaptureOrVerifyFn()));
m::GetTupleElement(m::CustomCall({kCudnnNormCallTarget})
.WithPredicate(custom_call->CaptureOrVerifyFn()),
2)
.WithPredicate(fused_norm_factor->CaptureOrVerifyFn()));
return m::AnyOf<HloInstruction>(shared_subpattern,
BitcastOrReshape(shared_subpattern));
}
@@ -784,7 +791,7 @@ auto XCenter(UniqueHloInstruction* x_center, UniqueHloInstruction* x,
};
return Subtract(m::Op(), m::Broadcast(FusedExpectation(fused_expectation,
custom_call)))
.WithPredicate(x_center->GetCaptureOrVerifyFn())
.WithPredicate(x_center->CaptureOrVerifyFn())
.WithPredicate(capture_or_verify_x);
}

@@ -806,7 +813,7 @@ auto F0(UniqueHloInstruction* custom_call, UniqueHloInstruction* scale,
reduce, MultiplyMultiplyAnyOrder(
XCenter(x, custom_call, norm_metadata),
m::Broadcast(m::Op().WithPredicate(capture_or_verify_scale)),
m::Op().WithPredicate(dy->GetCaptureOrVerifyFn())));
m::Op().WithPredicate(dy->CaptureOrVerifyFn())));
}

// Product of XCenter and the scaled and broadcasted product of F0 and
@@ -872,7 +879,7 @@ auto F2(UniqueHloInstruction* fused_norm_factor, UniqueHloInstruction* scale,
m::Broadcast(
BitcastOrReshape(FusedNormFactor(fused_norm_factor, custom_call))),
MultiplyAnyOrder(m::Broadcast().WithPredicate(capture_or_verify_scale),
m::Op().WithPredicate(dy->GetCaptureOrVerifyFn())));
m::Op().WithPredicate(dy->CaptureOrVerifyFn())));
}

class CudnnNormRewriterVisitor : public DfsHloRewriteVisitor {
@@ -902,13 +909,13 @@ class CudnnNormRewriterVisitor : public DfsHloRewriteVisitor {
instr,
SubtractMultiplyAddAnyOrder(
OptionalSupportedTransform(
m::Op().WithPredicate(x.GetCaptureOrVerifyFn())),
m::Op().WithPredicate(x.CaptureOrVerifyFn())),
Expectation(&expectation, &reduce,
OptionalSupportedTransform(m::Op().WithPredicate(
x.GetCaptureOrVerifyFn()))),
OptionalSupportedTransform(
m::Op().WithPredicate(x.CaptureOrVerifyFn()))),
NormFactor(&norm_factor, &x, &variance, &expectation, &epsilon),
m::Broadcast(&broadcast_scale, m::Op(&scale)),
m::Broadcast(&broadcast_bias, m::Op(&bias))))) {
Broadcast(&broadcast_scale, m::Op(&scale)),
Broadcast(&broadcast_bias, m::Op(&bias))))) {
#if CUDNN_VERSION < 8905
// Layer norm kernels are available with cuDNN 8.9.5 and above.
VLOG(1) << "Layer norm Custom Calls require cuDNN 8.9.5.";
@@ -949,7 +956,30 @@ class CudnnNormRewriterVisitor : public DfsHloRewriteVisitor {
}

// Verify the element types. The element types of input and output and the
// shapes of scale and bias must match.
// shapes of scale and bias must match. If a conversion to the type of the
// input is the only user of the output, set the output to the conversion.
// Similarly, if one of the users of the scale/bias is a conversion to the
// type of the bias/scale, set the scale/bias to the conversion.
if (instr->user_count() == 1 &&
instr->users()[0]->opcode() == HloOpcode::kConvert &&
ShapeUtil::SameElementType(instr->users()[0]->shape(),
x.Instr()->shape())) {
instr = instr->users()[0];
}
for (HloInstruction* scale_user : scale->users()) {
if (scale_user->opcode() == HloOpcode::kConvert &&
ShapeUtil::SameElementType(scale_user->shape(), bias->shape())) {
scale = scale_user;
break;
}
}
for (HloInstruction* bias_user : bias->users()) {
if (bias_user->opcode() == HloOpcode::kConvert &&
ShapeUtil::SameElementType(bias_user->shape(), scale->shape())) {
bias = bias_user;
break;
}
}
if (!CompatibleElementType(instr) || !CompatibleElementType(scale) ||
!CompatibleElementType(bias) ||
!ShapeUtil::SameElementType(instr->shape(), x.Instr()->shape()) ||
@@ -1134,12 +1164,11 @@ class CudnnNormRewriterVisitor : public DfsHloRewriteVisitor {
UniqueHloInstruction& epsilon) {
HloInstruction* gte = custom_call->users()[0];
if (Match(instr,
m::Divide(
m::Op(),
AddAnyOrder(
m::Op().WithPredicate(variance.GetCaptureOrVerifyFn()),
m::Broadcast(m::ConstantScalar().WithPredicate(
epsilon.GetCaptureOrVerifyFn())))))) {
m::Divide(m::Op(),
AddAnyOrder(
m::Op().WithPredicate(variance.CaptureOrVerifyFn()),
m::Broadcast(m::ConstantScalar().WithPredicate(
epsilon.CaptureOrVerifyFn())))))) {
// Verify the uniqueness of the operands.
if (!variance.Instr() || !epsilon.Instr()) {
VLOG(1) << "Layer norm operands not unique.";
63 changes: 63 additions & 0 deletions xla/service/gpu/transforms/cudnn_norm_rewriter_test.cc
@@ -535,6 +535,69 @@ TEST_F(CudnnNormRewriterTest, LayerNorm4D3IncorrectScaleBroadcast) {
TestNorm(hlo_text, optimized_hlo);
}

TEST_F(CudnnNormRewriterTest, LayerNorm4D3TypeConversion) {
const char* hlo_text = R"(
HloModule test
apply {
a = f32[] parameter(0)
b = f32[] parameter(1)
ROOT c = f32[] add(a,b)
}
ENTRY test {
input = f16[2,4,6,8] parameter(0)
input_f32 = f32[2,4,6,8] convert(input)
input_square = f32[2,4,6,8] multiply(input_f32, input_f32)
c0 = f32[] constant(0)
input_square_sum = f32[2,4,6] reduce(input_square, c0), dimensions={3}, to_apply=apply
r_nelems = f32[] constant(0.125)
r_nelems_bcast = f32[2,4,6] broadcast(r_nelems), dimensions={}
input_square_mean = f32[2,4,6] multiply(input_square_sum, r_nelems_bcast)
input_sum = f32[2,4,6] reduce(input_f32, c0), dimensions={3}, to_apply=apply
input_mean = f32[2,4,6] multiply(input_sum, r_nelems_bcast)
input_mean_square = f32[2,4,6] multiply(input_mean, input_mean)
variance = f32[2,4,6] subtract(input_square_mean, input_mean_square)
epsilon = f32[] constant(0.001)
epsilon_bcast = f32[2,4,6] broadcast(epsilon), dimensions={}
variance_plus_epsilon = f32[2,4,6] add(variance, epsilon_bcast)
norm_factor = f32[2,4,6] rsqrt(variance_plus_epsilon)
norm_factor_bcast = f32[2,4,6,8] broadcast(norm_factor), dimensions={0,1,2}
input_mean_bcast = f32[2,4,6,8] broadcast(input_mean), dimensions={0,1,2}
input_center = f32[2,4,6,8] subtract(input_f32, input_mean_bcast)
norm = f32[2,4,6,8] multiply(norm_factor_bcast, input_center)
scale = f16[8] parameter(1)
scale_f32 = f32[8] convert(scale)
scale_bcast = f32[2,4,6,8] broadcast(scale_f32), dimensions={3}
norm_scale = f32[2,4,6,8] multiply(norm, scale_bcast)
bias = f16[8] parameter(2)
bias_f32 = f32[8] convert(bias)
bias_bcast = f32[2,4,6,8] broadcast(bias_f32), dimensions={3}
norm_scale_bias = f32[2,4,6,8] add(norm_scale, bias_bcast)
ROOT out = f16[2,4,6,8] convert(norm_scale_bias)
})";

const char* optimized_hlo = R"(
; CHECK-LABEL: ENTRY %test ({{.*}}: f16[2,4,6,8], {{.*}}: f16[8], {{.*}}: f16[8]) -> f16[2,4,6,8] {
; CHECK-NEXT: [[P0:%[^ ]+]] = f16[2,4,6,8]{3,2,1,0} parameter(0)
; CHECK-NEXT: [[P0_BITCAST:%[^ ]+]] = f16[48,8,1,1]{3,2,1,0} bitcast([[P0]])
; CHECK-NEXT: [[P1:%[^ ]+]] = f16[8]{0} parameter(1)
; CHECK-NEXT: [[P1_BITCAST:%[^ ]+]] = f16[1,8,1,1]{3,2,1,0} bitcast([[P1]])
; CHECK-NEXT: [[P2:%[^ ]+]] = f16[8]{0} parameter(2)
; CHECK-NEXT: [[P2_BITCAST:%[^ ]+]] = f16[1,8,1,1]{3,2,1,0} bitcast([[P2]])
; CHECK-NEXT: [[CC:%[^ ]+]] = (f16[48,8,1,1]{3,2,1,0}, u8[{{.*}}]{0}) custom-call([[P0_BITCAST]], [[P1_BITCAST]], [[P2_BITCAST]]),
; CHECK: custom_call_target="__cudnn$norm",
; CHECK: backend_config={
; CHECK-DAG: "epsilon":0.001
; CHECK: }
; CHECK-NEXT: [[GTE:%[^ ]+]] = f16[48,8,1,1]{3,2,1,0} get-tuple-element([[CC]]), index=0
; CHECK-NEXT: ROOT {{.*}} = f16[2,4,6,8]{3,2,1,0} bitcast([[GTE]])
)";

TestNorm(hlo_text, optimized_hlo);
}

TEST_F(CudnnNormRewriterTest, LayerNorm4D3InputOutputTypeMismatch) {
const char* hlo_text = R"(
HloModule test
