Type Conversions in Layer Norm Fusion #17281
Closed
@@ -113,15 +113,19 @@ using NormMetadataMap = absl::flat_hash_map<HloInstruction*, NormMetadata>;
 // HloInstruction:
 //   UniqueHloInstruction x;
 //   bool m = Match(
-//       instr, m::Divide(m::Cos(m::Op().WithPredicate(x.capture_and_verify)),
-//                        m::Sin(m::Op().WithPredicate(x.capture_and_verify))));
+//       instr, m::Divide(m::Cos(m::Op().WithPredicate(x.CaptureOrVerifyFn())),
+//                        m::Sin(m::Op().WithPredicate(x.CaptureOrVerifyFn()))));
 // m is true and x.Instr() returns an HloInstruction pointer to the operand of
 // cosine and sine iff HloInstruction *instr points to a division of a cosine by
 // a sine that operate on the same instruction.
 class UniqueHloInstruction {
  public:
   UniqueHloInstruction()
-      : is_set_(false), instr_(nullptr), capture_or_verify_() {}
+      : is_set_(false),
+        instr_(nullptr),
+        capture_or_verify_([this](const HloInstruction* instr) -> bool {
+          return CaptureOrVerify(const_cast<HloInstruction*>(instr));
+        }) {}
   HloInstruction* Instr() const { return instr_; }
   void SetInstr(HloInstruction* instr) {
     is_set_ = true;
@@ -143,12 +147,7 @@ class UniqueHloInstruction {

   // Returns a std::function for capturing or verifying an instruction using
   // WithPredicate.
-  std::function<bool(const HloInstruction*)> GetCaptureOrVerifyFn() {
-    if (!capture_or_verify_) {
-      capture_or_verify_ = [this](const HloInstruction* instr) -> bool {
-        return CaptureOrVerify(const_cast<HloInstruction*>(instr));
-      };
-    }
+  std::function<bool(const HloInstruction*)> CaptureOrVerifyFn() const {
     return capture_or_verify_;
   }
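The refactor above builds the capture-or-verify functor eagerly in the constructor, which is what lets the getter drop its lazy-initialization branch and become const. Below is a minimal standalone sketch of that idea; the Instruction struct, the simplified members, and the CaptureOrVerify body are stand-ins inferred from the class comment, not the actual XLA code.

#include <functional>
#include <iostream>

// Stand-in for HloInstruction; for illustration only.
struct Instruction {
  int id;
};

// Simplified UniqueHloInstruction: the predicate is constructed eagerly, so
// CaptureOrVerifyFn() is a trivial const accessor.
class UniqueInstruction {
 public:
  UniqueInstruction()
      : is_set_(false),
        instr_(nullptr),
        capture_or_verify_([this](const Instruction* instr) -> bool {
          return CaptureOrVerify(const_cast<Instruction*>(instr));
        }) {}

  const Instruction* Instr() const { return instr_; }

  // Const getter: just returns the eagerly constructed functor.
  std::function<bool(const Instruction*)> CaptureOrVerifyFn() const {
    return capture_or_verify_;
  }

 private:
  // Captures the instruction on the first call; on later calls verifies that
  // the same instruction is seen again.
  bool CaptureOrVerify(Instruction* instr) {
    if (!is_set_) {
      is_set_ = true;
      instr_ = instr;
    }
    return instr_ == instr;
  }

  bool is_set_;
  Instruction* instr_;
  std::function<bool(const Instruction*)> capture_or_verify_;
};

int main() {
  Instruction a{1}, b{2};
  UniqueInstruction x;
  auto pred = x.CaptureOrVerifyFn();
  std::cout << pred(&a) << pred(&a) << pred(&b) << "\n";  // prints 110
}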
@@ -465,6 +464,16 @@ auto OptionalSupportedTransform(Pattern pattern) {
       SupportedBitcastOrReshape(shared_subpattern), shared_subpattern);
 }

+// Broadcast with optional supported type conversion.
+template <typename Pattern>
+auto Broadcast(HloInstruction** bcast, Pattern pattern) {
+  auto shared_subpattern = m::SharedSubpattern(pattern);
+  return m::AnyOf<HloInstruction>(
+      SupportedConvert(m::Broadcast(bcast, shared_subpattern)),
+      m::Broadcast(bcast, SupportedConvert(shared_subpattern)),
+      m::Broadcast(bcast, shared_subpattern));
+}
+
 // Bitcast or reshape with optional supported type conversion and/or addition or
 // removal of degenerate dimensions.
 template <typename Pattern>
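The new Broadcast helper accepts a convert on either side of the broadcast, or no convert at all, by listing the three alternatives in a single m::AnyOf. The same combinator shape can be illustrated with a toy matcher; Node, Matcher, and the helper names below are hypothetical stand-ins, not the XLA pattern library.

#include <functional>
#include <iostream>
#include <string>
#include <vector>

// Toy "instruction": an opcode plus a single-operand chain.
struct Node {
  std::string op;
  const Node* operand = nullptr;
};

using Matcher = std::function<bool(const Node*)>;

// Matches a node with the given opcode whose operand matches `inner`.
Matcher Op(const std::string& op, Matcher inner) {
  return [op, inner](const Node* n) {
    return n && n->op == op && inner(n->operand);
  };
}

// Matches anything, including a missing operand.
Matcher Any() { return [](const Node*) { return true; }; }

// Succeeds if any alternative matches, mirroring m::AnyOf.
Matcher AnyOf(std::vector<Matcher> alts) {
  return [alts](const Node* n) {
    for (const auto& alt : alts) {
      if (alt(n)) return true;
    }
    return false;
  };
}

// Mirrors the three alternatives of the new Broadcast() pattern:
// convert(broadcast(x)), broadcast(convert(x)), and plain broadcast(x).
Matcher BroadcastOptConvert(Matcher inner) {
  return AnyOf({Op("convert", Op("broadcast", inner)),
                Op("broadcast", Op("convert", inner)),
                Op("broadcast", inner)});
}

int main() {
  Node x{"param"};
  Node bcast{"broadcast", &x};
  Node conv{"convert", &bcast};
  std::cout << BroadcastOptConvert(Any())(&bcast)      // plain broadcast: 1
            << BroadcastOptConvert(Any())(&conv)       // convert(broadcast): 1
            << BroadcastOptConvert(Any())(&x) << "\n";  // bare param: 0
}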
@@ -597,7 +606,7 @@ auto Expectation(UniqueHloInstruction* expectation, Pattern pattern) {
           .WithPredicate([](const HloInstruction* instr) {
             return CalculatesExpectation(instr);
           })
-          .WithPredicate(expectation->GetCaptureOrVerifyFn()));
+          .WithPredicate(expectation->CaptureOrVerifyFn()));
   return m::AnyOf<HloInstruction>(m::Broadcast(shared_subpattern),
                                   shared_subpattern);
 }
@@ -612,7 +621,7 @@ auto Expectation(UniqueHloInstruction* expectation, HloInstruction** reduce,
           .WithPredicate([](const HloInstruction* instr) {
             return CalculatesExpectation(instr);
           })
-          .WithPredicate(expectation->GetCaptureOrVerifyFn()));
+          .WithPredicate(expectation->CaptureOrVerifyFn()));
   return m::AnyOf<HloInstruction>(m::Broadcast(shared_subpattern),
                                   shared_subpattern);
 }
@@ -624,19 +633,19 @@ auto Variance(UniqueHloInstruction* variance, UniqueHloInstruction* expectation,
   return m::AnyOf<HloInstruction>(
       Subtract(
           Expectation(Square(OptionalSupportedTransform(
-              m::Op().WithPredicate(x->GetCaptureOrVerifyFn())))),
-          Square(Expectation(expectation,
-                             OptionalSupportedTransform(m::Op().WithPredicate(
-                                 x->GetCaptureOrVerifyFn())))))
-          .WithPredicate(variance->GetCaptureOrVerifyFn()),
+              m::Op().WithPredicate(x->CaptureOrVerifyFn())))),
+          Square(Expectation(
+              expectation, OptionalSupportedTransform(
+                               m::Op().WithPredicate(x->CaptureOrVerifyFn())))))
+          .WithPredicate(variance->CaptureOrVerifyFn()),
       Expectation(
           Square(Subtract(
               OptionalSupportedTransform(
-                  m::Op().WithPredicate(x->GetCaptureOrVerifyFn())),
+                  m::Op().WithPredicate(x->CaptureOrVerifyFn())),
               Expectation(expectation,
-                          OptionalSupportedTransform(m::Op().WithPredicate(
-                              x->GetCaptureOrVerifyFn())))))
-          .WithPredicate(variance->GetCaptureOrVerifyFn()));
+                          OptionalSupportedTransform(
+                              m::Op().WithPredicate(x->CaptureOrVerifyFn())))))
+          .WithPredicate(variance->CaptureOrVerifyFn()));
 }

 // Reciprocal of the square root of variance + epsilon with optional broadcast.
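For reference, the two alternatives matched by Variance correspond to the two standard, algebraically equivalent forms of the variance:

$\operatorname{Var}(X) = E[X^2] - (E[X])^2 = E\left[(X - E[X])^2\right]$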
@@ -647,7 +656,7 @@ auto NormFactor(HloInstruction** norm_factor, UniqueHloInstruction* x,
   auto shared_subpattern = m::SharedSubpattern(Rsqrt(
       norm_factor, AddAnyOrder(Variance(variance, expectation, x),
                                m::Broadcast(m::ConstantScalar().WithPredicate(
-                                   epsilon->GetCaptureOrVerifyFn())))));
+                                   epsilon->CaptureOrVerifyFn())))));
   return m::AnyOf<HloInstruction>(m::Broadcast(shared_subpattern),
                                   shared_subpattern);
 }
@@ -696,10 +705,10 @@ auto SubtractMultiplyAddAnyOrder(P0 p0, P1 p1, P2 p2, P3 p3, P4 p4) {

 // Expectation fused into a layer norm Custom Call.
 auto FusedExpectation(UniqueHloInstruction* custom_call) {
-  auto shared_subpattern = m::SharedSubpattern(m::GetTupleElement(
-      m::CustomCall({kCudnnNormCallTarget})
-          .WithPredicate(custom_call->GetCaptureOrVerifyFn()),
-      1));
+  auto shared_subpattern = m::SharedSubpattern(
+      m::GetTupleElement(m::CustomCall({kCudnnNormCallTarget})
+                             .WithPredicate(custom_call->CaptureOrVerifyFn()),
+                         1));
   return m::AnyOf<HloInstruction>(shared_subpattern,
                                   BitcastOrReshape(shared_subpattern));
 }
@@ -708,21 +717,20 @@ auto FusedExpectation(UniqueHloInstruction* custom_call) {
 auto FusedExpectation(UniqueHloInstruction* fused_expectation,
                       UniqueHloInstruction* custom_call) {
   auto shared_subpattern = m::SharedSubpattern(
-      m::GetTupleElement(
-          m::CustomCall({kCudnnNormCallTarget})
-              .WithPredicate(custom_call->GetCaptureOrVerifyFn()),
-          1)
-          .WithPredicate(fused_expectation->GetCaptureOrVerifyFn()));
+      m::GetTupleElement(m::CustomCall({kCudnnNormCallTarget})
+                             .WithPredicate(custom_call->CaptureOrVerifyFn()),
+                         1)
+          .WithPredicate(fused_expectation->CaptureOrVerifyFn()));
   return m::AnyOf<HloInstruction>(shared_subpattern,
                                   BitcastOrReshape(shared_subpattern));
 }

 // Norm factor fused into a layer norm Custom Call.
 auto FusedNormFactor(UniqueHloInstruction* custom_call) {
-  auto shared_subpattern = m::SharedSubpattern(m::GetTupleElement(
-      m::CustomCall({kCudnnNormCallTarget})
-          .WithPredicate(custom_call->GetCaptureOrVerifyFn()),
-      2));
+  auto shared_subpattern = m::SharedSubpattern(
+      m::GetTupleElement(m::CustomCall({kCudnnNormCallTarget})
+                             .WithPredicate(custom_call->CaptureOrVerifyFn()),
+                         2));
   return m::AnyOf<HloInstruction>(shared_subpattern,
                                   BitcastOrReshape(shared_subpattern));
 }
@@ -731,11 +739,10 @@ auto FusedNormFactor(UniqueHloInstruction* custom_call) {
 auto FusedNormFactor(UniqueHloInstruction* fused_norm_factor,
                      UniqueHloInstruction* custom_call) {
   auto shared_subpattern = m::SharedSubpattern(
-      m::GetTupleElement(
-          m::CustomCall({kCudnnNormCallTarget})
-              .WithPredicate(custom_call->GetCaptureOrVerifyFn()),
-          2)
-          .WithPredicate(fused_norm_factor->GetCaptureOrVerifyFn()));
+      m::GetTupleElement(m::CustomCall({kCudnnNormCallTarget})
+                             .WithPredicate(custom_call->CaptureOrVerifyFn()),
+                         2)
+          .WithPredicate(fused_norm_factor->CaptureOrVerifyFn()));
   return m::AnyOf<HloInstruction>(shared_subpattern,
                                   BitcastOrReshape(shared_subpattern));
 }
@@ -784,7 +791,7 @@ auto XCenter(UniqueHloInstruction* x_center, UniqueHloInstruction* x,
   };
   return Subtract(m::Op(), m::Broadcast(FusedExpectation(fused_expectation,
                                                          custom_call)))
-      .WithPredicate(x_center->GetCaptureOrVerifyFn())
+      .WithPredicate(x_center->CaptureOrVerifyFn())
       .WithPredicate(capture_or_verify_x);
 }
@@ -806,7 +813,7 @@ auto F0(UniqueHloInstruction* custom_call, UniqueHloInstruction* scale,
       reduce, MultiplyMultiplyAnyOrder(
                   XCenter(x, custom_call, norm_metadata),
                   m::Broadcast(m::Op().WithPredicate(capture_or_verify_scale)),
-                  m::Op().WithPredicate(dy->GetCaptureOrVerifyFn())));
+                  m::Op().WithPredicate(dy->CaptureOrVerifyFn())));
 }

 // Product of XCenter and the scaled and broadcasted product of F0 and
@@ -872,7 +879,7 @@ auto F2(UniqueHloInstruction* fused_norm_factor, UniqueHloInstruction* scale,
       m::Broadcast(
           BitcastOrReshape(FusedNormFactor(fused_norm_factor, custom_call))),
       MultiplyAnyOrder(m::Broadcast().WithPredicate(capture_or_verify_scale),
-                       m::Op().WithPredicate(dy->GetCaptureOrVerifyFn())));
+                       m::Op().WithPredicate(dy->CaptureOrVerifyFn())));
 }

 class CudnnNormRewriterVisitor : public DfsHloRewriteVisitor {
@@ -902,13 +909,13 @@ class CudnnNormRewriterVisitor : public DfsHloRewriteVisitor {
             instr,
             SubtractMultiplyAddAnyOrder(
                 OptionalSupportedTransform(
-                    m::Op().WithPredicate(x.GetCaptureOrVerifyFn())),
+                    m::Op().WithPredicate(x.CaptureOrVerifyFn())),
                 Expectation(&expectation, &reduce,
-                            OptionalSupportedTransform(m::Op().WithPredicate(
-                                x.GetCaptureOrVerifyFn()))),
+                            OptionalSupportedTransform(
+                                m::Op().WithPredicate(x.CaptureOrVerifyFn()))),
                 NormFactor(&norm_factor, &x, &variance, &expectation, &epsilon),
-                m::Broadcast(&broadcast_scale, m::Op(&scale)),
-                m::Broadcast(&broadcast_bias, m::Op(&bias))))) {
+                Broadcast(&broadcast_scale, m::Op(&scale)),
+                Broadcast(&broadcast_bias, m::Op(&bias))))) {
 #if CUDNN_VERSION < 8905
       // Layer norm kernels are available with cuDNN 8.9.5 and above.
       VLOG(1) << "Layer norm Custom Calls require cuDNN 8.9.5.";
@@ -949,7 +956,30 @@ class CudnnNormRewriterVisitor : public DfsHloRewriteVisitor {
       }

       // Verify the element types. The element types of input and output and the
-      // shapes of scale and bias must match.
+      // shapes of scale and bias must match. If a conversion to the type of the
+      // input is the only user of the output, set the output to the conversion.
+      // Similarly, if one of the users of the scale/bias is a conversion to the
+      // type of the bias/scale, set the scale/bias to the conversion.
+      if (instr->user_count() == 1 &&
+          instr->users()[0]->opcode() == HloOpcode::kConvert &&
+          ShapeUtil::SameElementType(instr->users()[0]->shape(),
+                                     x.Instr()->shape())) {
+        instr = instr->users()[0];
+      }
+      for (HloInstruction* scale_user : scale->users()) {
+        if (scale_user->opcode() == HloOpcode::kConvert &&
+            ShapeUtil::SameElementType(scale_user->shape(), bias->shape())) {
+          scale = scale_user;
+          break;
+        }
+      }

Inline review comment (on the scale loop above): Why check if there is a convert user? This user might not even be part of the layer norm.

+      for (HloInstruction* bias_user : bias->users()) {
+        if (bias_user->opcode() == HloOpcode::kConvert &&
+            ShapeUtil::SameElementType(bias_user->shape(), scale->shape())) {
+          bias = bias_user;
+          break;
+        }
+      }
       if (!CompatibleElementType(instr) || !CompatibleElementType(scale) ||
           !CompatibleElementType(bias) ||
           !ShapeUtil::SameElementType(instr->shape(), x.Instr()->shape()) ||
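In the hunk above, the first added block follows the matched output into its sole convert user when that convert restores the input's element type, e.g. an f16 input whose norm arithmetic ran in f32. A toy sketch of that single step; Instr and its fields are stand-ins for XLA's HloInstruction API, for illustration only.

#include <iostream>
#include <vector>

enum class Opcode { kConvert, kAdd };

// Stand-in instruction: opcode, element type, and user list.
struct Instr {
  Opcode opcode;
  int element_type;  // stand-in for shape().element_type()
  std::vector<Instr*> users;
};

// If the output has exactly one user and that user is a convert back to the
// input's element type, the rewrite proceeds from the convert instead.
Instr* FollowSoleConvertUser(Instr* instr, const Instr* x) {
  if (instr->users.size() == 1 &&
      instr->users[0]->opcode == Opcode::kConvert &&
      instr->users[0]->element_type == x->element_type) {
    return instr->users[0];  // treat the convert as the real output
  }
  return instr;
}

int main() {
  // x has type 0 (say f16); the norm output has type 1 (say f32) and its sole
  // user converts the result back to type 0.
  Instr x{Opcode::kAdd, 0, {}};
  Instr convert_back{Opcode::kConvert, 0, {}};
  Instr norm_out{Opcode::kAdd, 1, {&convert_back}};
  std::cout << (FollowSoleConvertUser(&norm_out, &x) == &convert_back) << "\n";
}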
@@ -1134,12 +1164,11 @@ class CudnnNormRewriterVisitor : public DfsHloRewriteVisitor {
                               UniqueHloInstruction& epsilon) {
     HloInstruction* gte = custom_call->users()[0];
     if (Match(instr,
-              m::Divide(
-                  m::Op(),
-                  AddAnyOrder(
-                      m::Op().WithPredicate(variance.GetCaptureOrVerifyFn()),
-                      m::Broadcast(m::ConstantScalar().WithPredicate(
-                          epsilon.GetCaptureOrVerifyFn())))))) {
+              m::Divide(m::Op(),
+                        AddAnyOrder(
+                            m::Op().WithPredicate(variance.CaptureOrVerifyFn()),
+                            m::Broadcast(m::ConstantScalar().WithPredicate(
+                                epsilon.CaptureOrVerifyFn())))))) {
       // Verify the uniqueness of the operands.
       if (!variance.Instr() || !epsilon.Instr()) {
         VLOG(1) << "Layer norm operands not unique.";
Review comment: I am nervous about doing an incorrect rewrite based on these conversions you allow. What if the user, e.g., converts the input of layer norm from f32 to s2 and does all the layer norm logic in s2, before casting back to f32? It would be incorrect, I think, to rewrite this to a cuDNN layer norm at full precision. Can we at least check that all the types are floats with at least bf16 precision?
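A minimal sketch of the kind of element-type guard the reviewer is asking for: accept only floating-point types of at least 16-bit width (bf16/f16 and wider), rejecting narrower floats and non-float types. The enum and helper below are hypothetical stand-ins, not code from this PR.

#include <iostream>

// Stand-in enum for XLA's PrimitiveType; illustration only.
enum class Type { F8E4M3, BF16, F16, F32, F64, S8, S32, PRED };

// Accept only floating-point element types with at least bf16 width, so that
// narrow or integer intermediates are never rewritten into a full-precision
// cuDNN layer norm call.
bool IsSupportedNormType(Type t) {
  switch (t) {
    case Type::BF16:
    case Type::F16:
    case Type::F32:
    case Type::F64:
      return true;
    default:
      return false;  // f8 variants, integers, and PRED are rejected
  }
}

int main() {
  std::cout << IsSupportedNormType(Type::F32)         // 1
            << IsSupportedNormType(Type::S8) << "\n";  // 0
}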