diff --git a/xla/service/gpu/transforms/cudnn_norm_rewriter.cc b/xla/service/gpu/transforms/cudnn_norm_rewriter.cc
index 5d5e089933fd8..c1a1542b5cd70 100644
--- a/xla/service/gpu/transforms/cudnn_norm_rewriter.cc
+++ b/xla/service/gpu/transforms/cudnn_norm_rewriter.cc
@@ -113,15 +113,19 @@ using NormMetadataMap = absl::flat_hash_map<HloInstruction*, NormMetadata>;
 // HloInstruction:
 //   UniqueHloInstruction x;
 //   bool m = Match(
-//       instr, m::Divide(m::Cos(m::Op().WithPredicate(x.capture_and_verify)),
-//                        m::Sin(m::Op().WithPredicate(x.capture_and_verify))));
+//       instr, m::Divide(m::Cos(m::Op().WithPredicate(x.CaptureOrVerifyFn())),
+//                        m::Sin(m::Op().WithPredicate(x.CaptureOrVerifyFn()))));
 // m is true and x.Instr() returns an HloInstruction pointer to the operand of
 // cosine and sine iff HloInstruction *instr points to a division of a cosine by
 // a sine that operate on the same instruction.
 class UniqueHloInstruction {
  public:
   UniqueHloInstruction()
-      : is_set_(false), instr_(nullptr), capture_or_verify_() {}
+      : is_set_(false),
+        instr_(nullptr),
+        capture_or_verify_([this](const HloInstruction* instr) -> bool {
+          return CaptureOrVerify(const_cast<HloInstruction*>(instr));
+        }) {}
   HloInstruction* Instr() const { return instr_; }
   void SetInstr(HloInstruction* instr) {
     is_set_ = true;
@@ -143,12 +147,7 @@ class UniqueHloInstruction {
 
   // Returns a std::function for capturing or verifying an instruction using
   // WithPredicate.
-  std::function<bool(const HloInstruction*)> GetCaptureOrVerifyFn() {
-    if (!capture_or_verify_) {
-      capture_or_verify_ = [this](const HloInstruction* instr) -> bool {
-        return CaptureOrVerify(const_cast<HloInstruction*>(instr));
-      };
-    }
+  std::function<bool(const HloInstruction*)> CaptureOrVerifyFn() const {
     return capture_or_verify_;
   }
 
@@ -465,6 +464,16 @@ auto OptionalSupportedTransform(Pattern pattern) {
       SupportedBitcastOrReshape(shared_subpattern), shared_subpattern);
 }
 
+// Broadcast with optional supported type conversion.
+template <typename Pattern>
+auto Broadcast(HloInstruction** bcast, Pattern pattern) {
+  auto shared_subpattern = m::SharedSubpattern(pattern);
+  return m::AnyOf<HloInstruction>(
+      SupportedConvert(m::Broadcast(bcast, shared_subpattern)),
+      m::Broadcast(bcast, SupportedConvert(shared_subpattern)),
+      m::Broadcast(bcast, shared_subpattern));
+}
+
 // Bitcast or reshape with optional supported type conversion and/or addition or
 // removal of degenerate dimensions.
 template <typename Pattern>
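Note: the new Broadcast helper above is what later lets the layer norm matcher accept scale and bias operands that carry a supported type conversion on either side of their broadcast. A fragment of the kind of HLO it now tolerates, taken from the new test at the end of this patch (an f16 parameter converted to f32 before being broadcast):

    scale       = f16[8] parameter(1)
    scale_f32   = f32[8] convert(scale)
    scale_bcast = f32[2,4,6,8] broadcast(scale_f32), dimensions={3}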
@@ -597,7 +606,7 @@ auto Expectation(UniqueHloInstruction* expectation, Pattern pattern) {
               .WithPredicate([](const HloInstruction* instr) {
                 return CalculatesExpectation(instr);
               })
-              .WithPredicate(expectation->GetCaptureOrVerifyFn()));
+              .WithPredicate(expectation->CaptureOrVerifyFn()));
   return m::AnyOf<HloInstruction>(m::Broadcast(shared_subpattern),
                                   shared_subpattern);
 }
@@ -612,7 +621,7 @@ auto Expectation(UniqueHloInstruction* expectation, HloInstruction** reduce,
               .WithPredicate([](const HloInstruction* instr) {
                 return CalculatesExpectation(instr);
               })
-              .WithPredicate(expectation->GetCaptureOrVerifyFn()));
+              .WithPredicate(expectation->CaptureOrVerifyFn()));
   return m::AnyOf<HloInstruction>(m::Broadcast(shared_subpattern),
                                   shared_subpattern);
 }
@@ -624,19 +633,19 @@ auto Variance(UniqueHloInstruction* variance, UniqueHloInstruction* expectation,
   return m::AnyOf<HloInstruction>(
       Subtract(
           Expectation(Square(OptionalSupportedTransform(
-              m::Op().WithPredicate(x->GetCaptureOrVerifyFn())))),
-          Square(Expectation(expectation,
-                             OptionalSupportedTransform(m::Op().WithPredicate(
-                                 x->GetCaptureOrVerifyFn())))))
-          .WithPredicate(variance->GetCaptureOrVerifyFn()),
+              m::Op().WithPredicate(x->CaptureOrVerifyFn())))),
+          Square(Expectation(
+              expectation, OptionalSupportedTransform(
+                               m::Op().WithPredicate(x->CaptureOrVerifyFn())))))
+          .WithPredicate(variance->CaptureOrVerifyFn()),
       Expectation(
           Square(Subtract(
               OptionalSupportedTransform(
-                  m::Op().WithPredicate(x->GetCaptureOrVerifyFn())),
+                  m::Op().WithPredicate(x->CaptureOrVerifyFn())),
               Expectation(expectation,
-                          OptionalSupportedTransform(m::Op().WithPredicate(
-                              x->GetCaptureOrVerifyFn()))))))
-          .WithPredicate(variance->GetCaptureOrVerifyFn()));
+                          OptionalSupportedTransform(
+                              m::Op().WithPredicate(x->CaptureOrVerifyFn()))))))
+          .WithPredicate(variance->CaptureOrVerifyFn()));
 }
 
 // Reciprocal of the square root of variance + epsilon with optional broadcast.
@@ -647,7 +656,7 @@ auto NormFactor(HloInstruction** norm_factor, UniqueHloInstruction* x,
   auto shared_subpattern = m::SharedSubpattern(Rsqrt(
       norm_factor, AddAnyOrder(Variance(variance, expectation, x),
                                m::Broadcast(m::ConstantScalar().WithPredicate(
-                                   epsilon->GetCaptureOrVerifyFn())))));
+                                   epsilon->CaptureOrVerifyFn())))));
   return m::AnyOf<HloInstruction>(m::Broadcast(shared_subpattern),
                                   shared_subpattern);
 }
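For reference, the two alternatives matched by Variance above are the two standard formulations of the same statistic, so either algebraic form in the input HLO is recognized:

    // First branch:  Var[X] = E[X^2] - (E[X])^2
    //   Subtract(Expectation(Square(X)), Square(Expectation(X)))
    // Second branch: Var[X] = E[(X - E[X])^2]
    //   Expectation(Square(Subtract(X, Expectation(X))))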
@@ -696,10 +705,10 @@ auto SubtractMultiplyAddAnyOrder(P0 p0, P1 p1, P2 p2, P3 p3, P4 p4) {
 
 // Expectation fused into a layer norm Custom Call.
 auto FusedExpectation(UniqueHloInstruction* custom_call) {
-  auto shared_subpattern = m::SharedSubpattern(m::GetTupleElement(
-      m::CustomCall({kCudnnNormCallTarget})
-          .WithPredicate(custom_call->GetCaptureOrVerifyFn()),
-      1));
+  auto shared_subpattern = m::SharedSubpattern(
+      m::GetTupleElement(m::CustomCall({kCudnnNormCallTarget})
+                             .WithPredicate(custom_call->CaptureOrVerifyFn()),
+                         1));
   return m::AnyOf<HloInstruction>(shared_subpattern,
                                   BitcastOrReshape(shared_subpattern));
 }
@@ -708,21 +717,20 @@ auto FusedExpectation(UniqueHloInstruction* custom_call) {
 auto FusedExpectation(UniqueHloInstruction* fused_expectation,
                       UniqueHloInstruction* custom_call) {
   auto shared_subpattern = m::SharedSubpattern(
-      m::GetTupleElement(
-          m::CustomCall({kCudnnNormCallTarget})
-              .WithPredicate(custom_call->GetCaptureOrVerifyFn()),
-          1)
-          .WithPredicate(fused_expectation->GetCaptureOrVerifyFn()));
+      m::GetTupleElement(m::CustomCall({kCudnnNormCallTarget})
+                             .WithPredicate(custom_call->CaptureOrVerifyFn()),
+                         1)
+          .WithPredicate(fused_expectation->CaptureOrVerifyFn()));
   return m::AnyOf<HloInstruction>(shared_subpattern,
                                   BitcastOrReshape(shared_subpattern));
 }
 
 // Norm factor fused into a layer norm Custom Call.
 auto FusedNormFactor(UniqueHloInstruction* custom_call) {
-  auto shared_subpattern = m::SharedSubpattern(m::GetTupleElement(
-      m::CustomCall({kCudnnNormCallTarget})
-          .WithPredicate(custom_call->GetCaptureOrVerifyFn()),
-      2));
+  auto shared_subpattern = m::SharedSubpattern(
+      m::GetTupleElement(m::CustomCall({kCudnnNormCallTarget})
+                             .WithPredicate(custom_call->CaptureOrVerifyFn()),
+                         2));
   return m::AnyOf<HloInstruction>(shared_subpattern,
                                   BitcastOrReshape(shared_subpattern));
 }
@@ -731,11 +739,10 @@ auto FusedNormFactor(UniqueHloInstruction* custom_call) {
 auto FusedNormFactor(UniqueHloInstruction* fused_norm_factor,
                      UniqueHloInstruction* custom_call) {
   auto shared_subpattern = m::SharedSubpattern(
-      m::GetTupleElement(
-          m::CustomCall({kCudnnNormCallTarget})
-              .WithPredicate(custom_call->GetCaptureOrVerifyFn()),
-          2)
-          .WithPredicate(fused_norm_factor->GetCaptureOrVerifyFn()));
+      m::GetTupleElement(m::CustomCall({kCudnnNormCallTarget})
+                             .WithPredicate(custom_call->CaptureOrVerifyFn()),
+                         2)
+          .WithPredicate(fused_norm_factor->CaptureOrVerifyFn()));
   return m::AnyOf<HloInstruction>(shared_subpattern,
                                   BitcastOrReshape(shared_subpattern));
 }
@@ -784,7 +791,7 @@ auto XCenter(UniqueHloInstruction* x_center, UniqueHloInstruction* x,
   };
   return Subtract(m::Op(), m::Broadcast(FusedExpectation(fused_expectation,
                                                          custom_call)))
-      .WithPredicate(x_center->GetCaptureOrVerifyFn())
+      .WithPredicate(x_center->CaptureOrVerifyFn())
       .WithPredicate(capture_or_verify_x);
 }
 
@@ -806,7 +813,7 @@ auto F0(UniqueHloInstruction* custom_call, UniqueHloInstruction* scale,
       reduce, MultiplyMultiplyAnyOrder(
                   XCenter(x, custom_call, norm_metadata),
                   m::Broadcast(m::Op().WithPredicate(capture_or_verify_scale)),
-                  m::Op().WithPredicate(dy->GetCaptureOrVerifyFn())));
+                  m::Op().WithPredicate(dy->CaptureOrVerifyFn())));
 }
 
 // Product of XCenter and the scaled and broadcasted product of F0 and
@@ -872,7 +879,7 @@ auto F2(UniqueHloInstruction* fused_norm_factor, UniqueHloInstruction* scale,
       m::Broadcast(
           BitcastOrReshape(FusedNormFactor(fused_norm_factor, custom_call))),
       MultiplyAnyOrder(m::Broadcast().WithPredicate(capture_or_verify_scale),
-                       m::Op().WithPredicate(dy->GetCaptureOrVerifyFn())));
+                       m::Op().WithPredicate(dy->CaptureOrVerifyFn())));
 }
 
 class CudnnNormRewriterVisitor : public DfsHloRewriteVisitor {
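As the FusedExpectation and FusedNormFactor patterns above encode, a layer norm custom call with fused statistics returns them as tuple elements: index 1 holds the expectation and index 2 the norm factor. A minimal sketch of binding that layout through the renamed accessor; instr and the two locals are illustrative, not part of this change:

    // Illustrative only: binds the custom call and its fused expectation
    // (tuple index 1), tolerating an optional bitcast or reshape on top.
    UniqueHloInstruction fused_expectation, custom_call;
    bool matched =
        Match(instr, FusedExpectation(&fused_expectation, &custom_call));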
@@ -902,13 +909,13 @@ class CudnnNormRewriterVisitor : public DfsHloRewriteVisitor {
             instr,
             SubtractMultiplyAddAnyOrder(
                 OptionalSupportedTransform(
-                    m::Op().WithPredicate(x.GetCaptureOrVerifyFn())),
+                    m::Op().WithPredicate(x.CaptureOrVerifyFn())),
                 Expectation(&expectation, &reduce,
-                            OptionalSupportedTransform(m::Op().WithPredicate(
-                                x.GetCaptureOrVerifyFn()))),
+                            OptionalSupportedTransform(
+                                m::Op().WithPredicate(x.CaptureOrVerifyFn()))),
                 NormFactor(&norm_factor, &x, &variance, &expectation, &epsilon),
-                m::Broadcast(&broadcast_scale, m::Op(&scale)),
-                m::Broadcast(&broadcast_bias, m::Op(&bias))))) {
+                Broadcast(&broadcast_scale, m::Op(&scale)),
+                Broadcast(&broadcast_bias, m::Op(&bias))))) {
 #if CUDNN_VERSION < 8905
       // Layer norm kernels are available with cuDNN 8.9.5 and above.
       VLOG(1) << "Layer norm Custom Calls require cuDNN 8.9.5.";
@@ -949,7 +956,30 @@ class CudnnNormRewriterVisitor : public DfsHloRewriteVisitor {
       }
 
       // Verify the element types. The element types of input and output and the
-      // shapes of scale and bias must match.
+      // shapes of scale and bias must match. If a conversion to the type of the
+      // input is the only user of the output, set the output to the conversion.
+      // Similarly, if one of the users of the scale/bias is a conversion to the
+      // type of the bias/scale, set the scale/bias to the conversion.
+      if (instr->user_count() == 1 &&
+          instr->users()[0]->opcode() == HloOpcode::kConvert &&
+          ShapeUtil::SameElementType(instr->users()[0]->shape(),
+                                     x.Instr()->shape())) {
+        instr = instr->users()[0];
+      }
+      for (HloInstruction* scale_user : scale->users()) {
+        if (scale_user->opcode() == HloOpcode::kConvert &&
+            ShapeUtil::SameElementType(scale_user->shape(), bias->shape())) {
+          scale = scale_user;
+          break;
+        }
+      }
+      for (HloInstruction* bias_user : bias->users()) {
+        if (bias_user->opcode() == HloOpcode::kConvert &&
+            ShapeUtil::SameElementType(bias_user->shape(), scale->shape())) {
+          bias = bias_user;
+          break;
+        }
+      }
       if (!CompatibleElementType(instr) || !CompatibleElementType(scale) ||
           !CompatibleElementType(bias) ||
          !ShapeUtil::SameElementType(instr->shape(), x.Instr()->shape()) ||
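The user rewiring above is exercised end to end by the new LayerNorm4D3TypeConversion test below: the f32 pattern root has a single convert back to the f16 input type, and the f16 scale and bias parameters each have an f32 convert among their users, so the fused call is emitted entirely in f16:

    ROOT out = f16[2,4,6,8] convert(norm_scale_bias)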
@@ -1134,12 +1164,11 @@ class CudnnNormRewriterVisitor : public DfsHloRewriteVisitor {
                                 UniqueHloInstruction& epsilon) {
     HloInstruction* gte = custom_call->users()[0];
     if (Match(instr,
-              m::Divide(
-                  m::Op(),
-                  AddAnyOrder(
-                      m::Op().WithPredicate(variance.GetCaptureOrVerifyFn()),
-                      m::Broadcast(m::ConstantScalar().WithPredicate(
-                          epsilon.GetCaptureOrVerifyFn())))))) {
+              m::Divide(m::Op(),
+                        AddAnyOrder(
+                            m::Op().WithPredicate(variance.CaptureOrVerifyFn()),
+                            m::Broadcast(m::ConstantScalar().WithPredicate(
+                                epsilon.CaptureOrVerifyFn())))))) {
       // Verify the uniqueness of the operands.
       if (!variance.Instr() || !epsilon.Instr()) {
         VLOG(1) << "Layer norm operands not unique.";
diff --git a/xla/service/gpu/transforms/cudnn_norm_rewriter_test.cc b/xla/service/gpu/transforms/cudnn_norm_rewriter_test.cc
index a3dbc71132949..b21bbe0e0c197 100644
--- a/xla/service/gpu/transforms/cudnn_norm_rewriter_test.cc
+++ b/xla/service/gpu/transforms/cudnn_norm_rewriter_test.cc
@@ -535,6 +535,69 @@ TEST_F(CudnnNormRewriterTest, LayerNorm4D3IncorrectScaleBroadcast) {
   TestNorm(hlo_text, optimized_hlo);
 }
 
+TEST_F(CudnnNormRewriterTest, LayerNorm4D3TypeConversion) {
+  const char* hlo_text = R"(
+    HloModule test
+
+    apply {
+      a = f32[] parameter(0)
+      b = f32[] parameter(1)
+      ROOT c = f32[] add(a,b)
+    }
+
+    ENTRY test {
+        input = f16[2,4,6,8] parameter(0)
+        input_f32 = f32[2,4,6,8] convert(input)
+        input_square = f32[2,4,6,8] multiply(input_f32, input_f32)
+        c0 = f32[] constant(0)
+        input_square_sum = f32[2,4,6] reduce(input_square, c0), dimensions={3}, to_apply=apply
+        r_nelems = f32[] constant(0.125)
+        r_nelems_bcast = f32[2,4,6] broadcast(r_nelems), dimensions={}
+        input_square_mean = f32[2,4,6] multiply(input_square_sum, r_nelems_bcast)
+        input_sum = f32[2,4,6] reduce(input_f32, c0), dimensions={3}, to_apply=apply
+        input_mean = f32[2,4,6] multiply(input_sum, r_nelems_bcast)
+        input_mean_square = f32[2,4,6] multiply(input_mean, input_mean)
+        variance = f32[2,4,6] subtract(input_square_mean, input_mean_square)
+        epsilon = f32[] constant(0.001)
+        epsilon_bcast = f32[2,4,6] broadcast(epsilon), dimensions={}
+        variance_plus_epsilon = f32[2,4,6] add(variance, epsilon_bcast)
+        norm_factor = f32[2,4,6] rsqrt(variance_plus_epsilon)
+        norm_factor_bcast = f32[2,4,6,8] broadcast(norm_factor), dimensions={0,1,2}
+        input_mean_bcast = f32[2,4,6,8] broadcast(input_mean), dimensions={0,1,2}
+        input_center = f32[2,4,6,8] subtract(input_f32, input_mean_bcast)
+        norm = f32[2,4,6,8] multiply(norm_factor_bcast, input_center)
+        scale = f16[8] parameter(1)
+        scale_f32 = f32[8] convert(scale)
+        scale_bcast = f32[2,4,6,8] broadcast(scale_f32), dimensions={3}
+        norm_scale = f32[2,4,6,8] multiply(norm, scale_bcast)
+        bias = f16[8] parameter(2)
+        bias_f32 = f32[8] convert(bias)
+        bias_bcast = f32[2,4,6,8] broadcast(bias_f32), dimensions={3}
+        norm_scale_bias = f32[2,4,6,8] add(norm_scale, bias_bcast)
+        ROOT out = f16[2,4,6,8] convert(norm_scale_bias)
+    })";
+
+  const char* optimized_hlo = R"(
+
+; CHECK-LABEL: ENTRY %test ({{.*}}: f16[2,4,6,8], {{.*}}: f16[8], {{.*}}: f16[8]) -> f16[2,4,6,8] {
+; CHECK-NEXT:    [[P0:%[^ ]+]] = f16[2,4,6,8]{3,2,1,0} parameter(0)
+; CHECK-NEXT:    [[P0_BITCAST:%[^ ]+]] = f16[48,8,1,1]{3,2,1,0} bitcast([[P0]])
+; CHECK-NEXT:    [[P1:%[^ ]+]] = f16[8]{0} parameter(1)
+; CHECK-NEXT:    [[P1_BITCAST:%[^ ]+]] = f16[1,8,1,1]{3,2,1,0} bitcast([[P1]])
+; CHECK-NEXT:    [[P2:%[^ ]+]] = f16[8]{0} parameter(2)
+; CHECK-NEXT:    [[P2_BITCAST:%[^ ]+]] = f16[1,8,1,1]{3,2,1,0} bitcast([[P2]])
+; CHECK-NEXT:    [[CC:%[^ ]+]] = (f16[48,8,1,1]{3,2,1,0}, u8[{{.*}}]{0}) custom-call([[P0_BITCAST]], [[P1_BITCAST]], [[P2_BITCAST]]),
+; CHECK:           custom_call_target="__cudnn$norm",
+; CHECK:           backend_config={
+; CHECK-DAG:         "epsilon":0.001
+; CHECK:           }
+; CHECK-NEXT:    [[GTE:%[^ ]+]] = f16[48,8,1,1]{3,2,1,0} get-tuple-element([[CC]]), index=0
+; CHECK-NEXT:  ROOT {{.*}} = f16[2,4,6,8]{3,2,1,0} bitcast([[GTE]])
+  )";
+
+  TestNorm(hlo_text, optimized_hlo);
+}
+
 TEST_F(CudnnNormRewriterTest, LayerNorm4D3InputOutputTypeMismatch) {
   const char* hlo_text = R"(
     HloModule test