Type Conversions in Layer Norm Fusion #17281

Closed · wants to merge 2 commits

Changes from 1 commit
137 changes: 83 additions & 54 deletions xla/service/gpu/transforms/cudnn_norm_rewriter.cc
@@ -113,15 +113,19 @@ using NormMetadataMap = absl::flat_hash_map<HloInstruction*, NormMetadata>;
// HloInstruction:
// UniqueHloInstruction x;
// bool m = Match(
// instr, m::Divide(m::Cos(m::Op().WithPredicate(x.capture_and_verify)),
// m::Sin(m::Op().WithPredicate(x.capture_and_verify))));
// instr, m::Divide(m::Cos(m::Op().WithPredicate(x.CaptureOrVerifyFn())),
// m::Sin(m::Op().WithPredicate(x.CaptureOrVerifyFn()))));
// m is true and x.Instr() returns an HloInstruction pointer to the operand of
// cosine and sine iff HloInstruction *instr points to a division of a cosine by
// a sine that operate on the same instruction.
class UniqueHloInstruction {
public:
UniqueHloInstruction()
: is_set_(false), instr_(nullptr), capture_or_verify_() {}
: is_set_(false),
instr_(nullptr),
capture_or_verify_([this](const HloInstruction* instr) -> bool {
return CaptureOrVerify(const_cast<HloInstruction*>(instr));
}) {}
HloInstruction* Instr() const { return instr_; }
void SetInstr(HloInstruction* instr) {
is_set_ = true;
@@ -143,12 +147,7 @@ class UniqueHloInstruction {

// Returns a std::function for capturing or verifying an instruction using
// WithPredicate.
std::function<bool(const HloInstruction*)> GetCaptureOrVerifyFn() {
if (!capture_or_verify_) {
capture_or_verify_ = [this](const HloInstruction* instr) -> bool {
return CaptureOrVerify(const_cast<HloInstruction*>(instr));
};
}
std::function<bool(const HloInstruction*)> CaptureOrVerifyFn() const {
return capture_or_verify_;
}

@@ -465,6 +464,16 @@ auto OptionalSupportedTransform(Pattern pattern) {
SupportedBitcastOrReshape(shared_subpattern), shared_subpattern);
}

// Broadcast with optional supported type conversion.
template <typename Pattern>
auto Broadcast(HloInstruction** bcast, Pattern pattern) {
auto shared_subpattern = m::SharedSubpattern(pattern);
return m::AnyOf<HloInstruction>(
      SupportedConvert(m::Broadcast(bcast, shared_subpattern)),
      m::Broadcast(bcast, SupportedConvert(shared_subpattern)),
      m::Broadcast(bcast, shared_subpattern));
}
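
A minimal usage sketch, mirroring the Broadcast(&broadcast_scale, m::Op(&scale)) call in the layer norm pattern below and the broadcast-over-convert of the f16 scale in the new test:

// The three alternatives the helper accepts for a broadcasted parameter,
// capturing the broadcast instruction in `bcast`:
//   convert(broadcast(param))  -- supported convert above the broadcast
//   broadcast(convert(param))  -- supported convert below the broadcast
//   broadcast(param)           -- no conversion
HloInstruction* broadcast_scale = nullptr;
HloInstruction* scale = nullptr;
auto scale_pattern = Broadcast(&broadcast_scale, m::Op(&scale));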

// Bitcast or reshape with optional supported type conversion and/or addition or
// removal of degenerate dimensions.
template <typename Pattern>
@@ -597,7 +606,7 @@ auto Expectation(UniqueHloInstruction* expectation, Pattern pattern) {
.WithPredicate([](const HloInstruction* instr) {
return CalculatesExpectation(instr);
})
.WithPredicate(expectation->GetCaptureOrVerifyFn()));
.WithPredicate(expectation->CaptureOrVerifyFn()));
return m::AnyOf<HloInstruction>(m::Broadcast(shared_subpattern),
shared_subpattern);
}
@@ -612,7 +621,7 @@ auto Expectation(UniqueHloInstruction* expectation, HloInstruction** reduce,
.WithPredicate([](const HloInstruction* instr) {
return CalculatesExpectation(instr);
})
.WithPredicate(expectation->GetCaptureOrVerifyFn()));
.WithPredicate(expectation->CaptureOrVerifyFn()));
return m::AnyOf<HloInstruction>(m::Broadcast(shared_subpattern),
shared_subpattern);
}
@@ -624,19 +633,19 @@ auto Variance(UniqueHloInstruction* variance, UniqueHloInstruction* expectation,
return m::AnyOf<HloInstruction>(
Subtract(
Expectation(Square(OptionalSupportedTransform(
m::Op().WithPredicate(x->GetCaptureOrVerifyFn())))),
Square(Expectation(expectation,
OptionalSupportedTransform(m::Op().WithPredicate(
x->GetCaptureOrVerifyFn())))))
.WithPredicate(variance->GetCaptureOrVerifyFn()),
m::Op().WithPredicate(x->CaptureOrVerifyFn())))),
Square(Expectation(
expectation, OptionalSupportedTransform(
m::Op().WithPredicate(x->CaptureOrVerifyFn())))))
.WithPredicate(variance->CaptureOrVerifyFn()),
Expectation(
Square(Subtract(
OptionalSupportedTransform(
m::Op().WithPredicate(x->GetCaptureOrVerifyFn())),
m::Op().WithPredicate(x->CaptureOrVerifyFn())),
Expectation(expectation,
OptionalSupportedTransform(m::Op().WithPredicate(
x->GetCaptureOrVerifyFn()))))))
.WithPredicate(variance->GetCaptureOrVerifyFn()));
OptionalSupportedTransform(
m::Op().WithPredicate(x->CaptureOrVerifyFn()))))))
.WithPredicate(variance->CaptureOrVerifyFn()));
}
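
For reference, the two alternatives matched here are the two standard forms of the variance: the first branch matches Var[x] = E[x^2] - (E[x])^2 and the second matches Var[x] = E[(x - E[x])^2].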

// Reciprocal of the square root of variance + epsilon with optional broadcast.
@@ -647,7 +656,7 @@ auto NormFactor(HloInstruction** norm_factor, UniqueHloInstruction* x,
auto shared_subpattern = m::SharedSubpattern(Rsqrt(
norm_factor, AddAnyOrder(Variance(variance, expectation, x),
m::Broadcast(m::ConstantScalar().WithPredicate(
epsilon->GetCaptureOrVerifyFn())))));
epsilon->CaptureOrVerifyFn())))));
return m::AnyOf<HloInstruction>(m::Broadcast(shared_subpattern),
shared_subpattern);
}
@@ -696,10 +705,10 @@ auto SubtractMultiplyAddAnyOrder(P0 p0, P1 p1, P2 p2, P3 p3, P4 p4) {

// Expectation fused into a layer norm Custom Call.
auto FusedExpectation(UniqueHloInstruction* custom_call) {
auto shared_subpattern = m::SharedSubpattern(m::GetTupleElement(
m::CustomCall({kCudnnNormCallTarget})
.WithPredicate(custom_call->GetCaptureOrVerifyFn()),
1));
auto shared_subpattern = m::SharedSubpattern(
m::GetTupleElement(m::CustomCall({kCudnnNormCallTarget})
.WithPredicate(custom_call->CaptureOrVerifyFn()),
1));
return m::AnyOf<HloInstruction>(shared_subpattern,
BitcastOrReshape(shared_subpattern));
}
@@ -708,21 +717,20 @@ auto FusedExpectation(UniqueHloInstruction* custom_call) {
auto FusedExpectation(UniqueHloInstruction* fused_expectation,
UniqueHloInstruction* custom_call) {
auto shared_subpattern = m::SharedSubpattern(
m::GetTupleElement(
m::CustomCall({kCudnnNormCallTarget})
.WithPredicate(custom_call->GetCaptureOrVerifyFn()),
1)
.WithPredicate(fused_expectation->GetCaptureOrVerifyFn()));
m::GetTupleElement(m::CustomCall({kCudnnNormCallTarget})
.WithPredicate(custom_call->CaptureOrVerifyFn()),
1)
.WithPredicate(fused_expectation->CaptureOrVerifyFn()));
return m::AnyOf<HloInstruction>(shared_subpattern,
BitcastOrReshape(shared_subpattern));
}

// Norm factor fused into a layer norm Custom Call.
auto FusedNormFactor(UniqueHloInstruction* custom_call) {
auto shared_subpattern = m::SharedSubpattern(m::GetTupleElement(
m::CustomCall({kCudnnNormCallTarget})
.WithPredicate(custom_call->GetCaptureOrVerifyFn()),
2));
auto shared_subpattern = m::SharedSubpattern(
m::GetTupleElement(m::CustomCall({kCudnnNormCallTarget})
.WithPredicate(custom_call->CaptureOrVerifyFn()),
2));
return m::AnyOf<HloInstruction>(shared_subpattern,
BitcastOrReshape(shared_subpattern));
}
@@ -731,11 +739,10 @@ auto FusedNormFactor(UniqueHloInstruction* custom_call) {
auto FusedNormFactor(UniqueHloInstruction* fused_norm_factor,
UniqueHloInstruction* custom_call) {
auto shared_subpattern = m::SharedSubpattern(
m::GetTupleElement(
m::CustomCall({kCudnnNormCallTarget})
.WithPredicate(custom_call->GetCaptureOrVerifyFn()),
2)
.WithPredicate(fused_norm_factor->GetCaptureOrVerifyFn()));
m::GetTupleElement(m::CustomCall({kCudnnNormCallTarget})
.WithPredicate(custom_call->CaptureOrVerifyFn()),
2)
.WithPredicate(fused_norm_factor->CaptureOrVerifyFn()));
return m::AnyOf<HloInstruction>(shared_subpattern,
BitcastOrReshape(shared_subpattern));
}
@@ -784,7 +791,7 @@ auto XCenter(UniqueHloInstruction* x_center, UniqueHloInstruction* x,
};
return Subtract(m::Op(), m::Broadcast(FusedExpectation(fused_expectation,
custom_call)))
.WithPredicate(x_center->GetCaptureOrVerifyFn())
.WithPredicate(x_center->CaptureOrVerifyFn())
.WithPredicate(capture_or_verify_x);
}

@@ -806,7 +813,7 @@ auto F0(UniqueHloInstruction* custom_call, UniqueHloInstruction* scale,
reduce, MultiplyMultiplyAnyOrder(
XCenter(x, custom_call, norm_metadata),
m::Broadcast(m::Op().WithPredicate(capture_or_verify_scale)),
m::Op().WithPredicate(dy->GetCaptureOrVerifyFn())));
m::Op().WithPredicate(dy->CaptureOrVerifyFn())));
}

// Product of XCenter and the scaled and broadcasted product of F0 and
@@ -872,7 +879,7 @@ auto F2(UniqueHloInstruction* fused_norm_factor, UniqueHloInstruction* scale,
m::Broadcast(
BitcastOrReshape(FusedNormFactor(fused_norm_factor, custom_call))),
MultiplyAnyOrder(m::Broadcast().WithPredicate(capture_or_verify_scale),
m::Op().WithPredicate(dy->GetCaptureOrVerifyFn())));
m::Op().WithPredicate(dy->CaptureOrVerifyFn())));
}

class CudnnNormRewriterVisitor : public DfsHloRewriteVisitor {
Expand Down Expand Up @@ -902,13 +909,13 @@ class CudnnNormRewriterVisitor : public DfsHloRewriteVisitor {
instr,
SubtractMultiplyAddAnyOrder(
OptionalSupportedTransform(
m::Op().WithPredicate(x.GetCaptureOrVerifyFn())),
m::Op().WithPredicate(x.CaptureOrVerifyFn())),
Expectation(&expectation, &reduce,
OptionalSupportedTransform(m::Op().WithPredicate(
x.GetCaptureOrVerifyFn()))),
OptionalSupportedTransform(
m::Op().WithPredicate(x.CaptureOrVerifyFn()))),
NormFactor(&norm_factor, &x, &variance, &expectation, &epsilon),
m::Broadcast(&broadcast_scale, m::Op(&scale)),
m::Broadcast(&broadcast_bias, m::Op(&bias))))) {
Broadcast(&broadcast_scale, m::Op(&scale)),
Broadcast(&broadcast_bias, m::Op(&bias))))) {
#if CUDNN_VERSION < 8905
// Layer norm kernels are available with cuDNN 8.9.5 and above.
VLOG(1) << "Layer norm Custom Calls require cuDNN 8.9.5.";
@@ -949,7 +956,30 @@ class CudnnNormRewriterVisitor : public DfsHloRewriteVisitor {
}

// Verify the element types. The element types of input and output and the
// shapes of scale and bias must match.
// shapes of scale and bias must match. If a conversion to the type of the
// input is the only user of the output, set the output to the conversion.
// Similarly, if one of the users of the scale/bias is a conversion to the
// type of the bias/scale, set the scale/bias to the conversion.
if (instr->user_count() == 1 &&
    instr->users()[0]->opcode() == HloOpcode::kConvert &&
    ShapeUtil::SameElementType(instr->users()[0]->shape(),
                               x.Instr()->shape())) {
  instr = instr->users()[0];
}

Review comment (Member): I am nervous about doing an incorrect rewrite based on these conversions you allow. What if the user, e.g., converts the input of the layer norm from f32 to s2 and does all of the layer norm logic in s2 before casting back to f32? It would be incorrect, I think, to rewrite this to a cuDNN layer norm at full precision. Can we at least check that all the types are floats of at least bf16 precision?
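
A minimal sketch of the kind of type guard the reviewer asks for (illustrative only, not part of this PR; the helper name IsWideEnoughFloat is hypothetical):

// Hypothetical guard: only fold a convert whose element type is a
// floating-point type of at least 16-bit width, so an integer- or fp8-typed
// layer norm is never rewritten into a higher-precision cuDNN call.
bool IsWideEnoughFloat(const HloInstruction* instr) {
  PrimitiveType type = instr->shape().element_type();
  return primitive_util::IsFloatingPointType(type) &&
         primitive_util::BitWidth(type) >= 16;
}

Under such a scheme, the branch above would additionally require IsWideEnoughFloat(instr) && IsWideEnoughFloat(instr->users()[0]) before replacing instr with its convert user, and the scale/bias loops below would apply the same check.
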
for (HloInstruction* scale_user : scale->users()) {
if (scale_user->opcode() == HloOpcode::kConvert &&
ShapeUtil::SameElementType(scale_user->shape(), bias->shape())) {
scale = scale_user;
break;
}
}

Review comment (Member): Why check if there is a convert user? This user might not even be part of the layer norm.

for (HloInstruction* bias_user : bias->users()) {
if (bias_user->opcode() == HloOpcode::kConvert &&
ShapeUtil::SameElementType(bias_user->shape(), scale->shape())) {
bias = bias_user;
break;
}
}
if (!CompatibleElementType(instr) || !CompatibleElementType(scale) ||
!CompatibleElementType(bias) ||
!ShapeUtil::SameElementType(instr->shape(), x.Instr()->shape()) ||
Expand Down Expand Up @@ -1134,12 +1164,11 @@ class CudnnNormRewriterVisitor : public DfsHloRewriteVisitor {
UniqueHloInstruction& epsilon) {
HloInstruction* gte = custom_call->users()[0];
if (Match(instr,
m::Divide(
m::Op(),
AddAnyOrder(
m::Op().WithPredicate(variance.GetCaptureOrVerifyFn()),
m::Broadcast(m::ConstantScalar().WithPredicate(
epsilon.GetCaptureOrVerifyFn())))))) {
m::Divide(m::Op(),
AddAnyOrder(
m::Op().WithPredicate(variance.CaptureOrVerifyFn()),
m::Broadcast(m::ConstantScalar().WithPredicate(
epsilon.CaptureOrVerifyFn())))))) {
// Verify the uniqueness of the operands.
if (!variance.Instr() || !epsilon.Instr()) {
VLOG(1) << "Layer norm operands not unique.";
63 changes: 63 additions & 0 deletions xla/service/gpu/transforms/cudnn_norm_rewriter_test.cc
@@ -535,6 +535,69 @@ TEST_F(CudnnNormRewriterTest, LayerNorm4D3IncorrectScaleBroadcast) {
TestNorm(hlo_text, optimized_hlo);
}

TEST_F(CudnnNormRewriterTest, LayerNorm4D3TypeConversion) {
const char* hlo_text = R"(
HloModule test

apply {
a = f32[] parameter(0)
b = f32[] parameter(1)
ROOT c = f32[] add(a,b)
}

ENTRY test {
input = f16[2,4,6,8] parameter(0)
input_f32 = f32[2,4,6,8] convert(input)
input_square = f32[2,4,6,8] multiply(input_f32, input_f32)
c0 = f32[] constant(0)
input_square_sum = f32[2,4,6] reduce(input_square, c0), dimensions={3}, to_apply=apply
r_nelems = f32[] constant(0.125)
r_nelems_bcast = f32[2,4,6] broadcast(r_nelems), dimensions={}
input_square_mean = f32[2,4,6] multiply(input_square_sum, r_nelems_bcast)
input_sum = f32[2,4,6] reduce(input_f32, c0), dimensions={3}, to_apply=apply
input_mean = f32[2,4,6] multiply(input_sum, r_nelems_bcast)
input_mean_square = f32[2,4,6] multiply(input_mean, input_mean)
variance = f32[2,4,6] subtract(input_square_mean, input_mean_square)
epsilon = f32[] constant(0.001)
epsilon_bcast = f32[2,4,6] broadcast(epsilon), dimensions={}
variance_plus_epsilon = f32[2,4,6] add(variance, epsilon_bcast)
norm_factor = f32[2,4,6] rsqrt(variance_plus_epsilon)
norm_factor_bcast = f32[2,4,6,8] broadcast(norm_factor), dimensions={0,1,2}
input_mean_bcast = f32[2,4,6,8] broadcast(input_mean), dimensions={0,1,2}
input_center = f32[2,4,6,8] subtract(input_f32, input_mean_bcast)
norm = f32[2,4,6,8] multiply(norm_factor_bcast, input_center)
scale = f16[8] parameter(1)
scale_f32 = f32[8] convert(scale)
scale_bcast = f32[2,4,6,8] broadcast(scale_f32), dimensions={3}
norm_scale = f32[2,4,6,8] multiply(norm, scale_bcast)
bias = f16[8] parameter(2)
bias_f32 = f32[8] convert(bias)
bias_bcast = f32[2,4,6,8] broadcast(bias_f32), dimensions={3}
norm_scale_bias = f32[2,4,6,8] add(norm_scale, bias_bcast)
ROOT out = f16[2,4,6,8] convert(norm_scale_bias)
})";

const char* optimized_hlo = R"(

; CHECK-LABEL: ENTRY %test ({{.*}}: f16[2,4,6,8], {{.*}}: f16[8], {{.*}}: f16[8]) -> f16[2,4,6,8] {
; CHECK-NEXT: [[P0:%[^ ]+]] = f16[2,4,6,8]{3,2,1,0} parameter(0)
; CHECK-NEXT: [[P0_BITCAST:%[^ ]+]] = f16[48,8,1,1]{3,2,1,0} bitcast([[P0]])
; CHECK-NEXT: [[P1:%[^ ]+]] = f16[8]{0} parameter(1)
; CHECK-NEXT: [[P1_BITCAST:%[^ ]+]] = f16[1,8,1,1]{3,2,1,0} bitcast([[P1]])
; CHECK-NEXT: [[P2:%[^ ]+]] = f16[8]{0} parameter(2)
; CHECK-NEXT: [[P2_BITCAST:%[^ ]+]] = f16[1,8,1,1]{3,2,1,0} bitcast([[P2]])
; CHECK-NEXT: [[CC:%[^ ]+]] = (f16[48,8,1,1]{3,2,1,0}, u8[{{.*}}]{0}) custom-call([[P0_BITCAST]], [[P1_BITCAST]], [[P2_BITCAST]]),
; CHECK: custom_call_target="__cudnn$norm",
; CHECK: backend_config={
; CHECK-DAG: "epsilon":0.001
; CHECK: }
; CHECK-NEXT: [[GTE:%[^ ]+]] = f16[48,8,1,1]{3,2,1,0} get-tuple-element([[CC]]), index=0
; CHECK-NEXT: ROOT {{.*}} = f16[2,4,6,8]{3,2,1,0} bitcast([[GTE]])
)";

TestNorm(hlo_text, optimized_hlo);
}

TEST_F(CudnnNormRewriterTest, LayerNorm4D3InputOutputTypeMismatch) {
const char* hlo_text = R"(
HloModule test