Norms Refactor #1140
Conversation
Signed-off-by: Phuong Nguyen <[email protected]>
/te-ci pytorch
```cpp
#ifndef TRANSFORMER_ENGINE_COMMON_LAYER_NORM_LN_H_
#define TRANSFORMER_ENGINE_COMMON_LAYER_NORM_LN_H_
```
Suggested change:

```diff
-#ifndef TRANSFORMER_ENGINE_COMMON_LAYER_NORM_LN_H_
-#define TRANSFORMER_ENGINE_COMMON_LAYER_NORM_LN_H_
+#ifndef TRANSFORMER_ENGINE_COMMON_LAYER_NORM_NORMS_H_
+#define TRANSFORMER_ENGINE_COMMON_LAYER_NORM_NORMS_H_
```
```cpp
#include "../common.h"

namespace transformer_engine {
```
I think it would be cleaner if we wrapped these objects in a namespace:

Suggested change:

```diff
 namespace transformer_engine {
+namespace norms {
```

The names of these classes (`LaunchParams`, `FwdParams`, etc.) are not very specific, so putting them in a namespace helps make clear that they are related to LayerNorm/RMSNorm.
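For illustration, a rough sketch of the suggested layout (member lists elided, purely illustrative):

```cpp
namespace transformer_engine {
namespace norms {

// Parameter structs live under norms::, so call sites read
// transformer_engine::norms::FwdParams, making the LayerNorm/RMSNorm
// association explicit.
struct LaunchParams { /* ... */ };
struct FwdParams { /* ... */ };

}  // namespace norms
}  // namespace transformer_engine
```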
```cpp
                     z->data.dtype,    // otype
                     DType::kFloat32,  // ctype
                     params);
  if (params.fp8_out) set_amax();
```
This zeroing might be unnecessary: the FP8 scale-update kernel called by PyTorch also zeros out the amax. That said, this matches the existing behavior, so there is no need to change it right now.
```cpp
  // Scaling factor
  void* scale;
  int scale_byte_size;
```
Are we using this value?

Suggested change:

```diff
-  int scale_byte_size;
```
```cpp
  // AMax output
  void* amax;
  int amax_byte_size;
```
If we remove the amax zeroing kernel, I don't think we are using this value:

Suggested change:

```diff
-  int amax_byte_size;
```

If we want to keep an option to zero out the amax, then could we just assume the amax is FP32 like we do in the kernels?
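For reference, a minimal sketch of what that could look like; the helper name `zero_amax` is hypothetical:

```cpp
#include <cuda_runtime.h>

// Zero the amax on the given stream, assuming it is always a single FP32
// value (as the kernels already do), so no separate byte-size field is needed.
void zero_amax(void* amax, cudaStream_t stream) {
  cudaMemsetAsync(amax, 0, sizeof(float), stream);
}
```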
```cpp
enum NVTE_NORM_TYPE {
  LN_FWD_TE,
  LN_BWD_TE,
  LN_FWD_CUDNN,
  LN_BWD_CUDNN,
  RMS_FWD_TE,
  RMS_BWD_TE,
  RMS_FWD_CUDNN,
  RMS_BWD_CUDNN,
};
```
Logically this can be split into three enums:

Suggested change:

```diff
-enum NVTE_NORM_TYPE {
-  LN_FWD_TE,
-  LN_BWD_TE,
-  LN_FWD_CUDNN,
-  LN_BWD_CUDNN,
-  RMS_FWD_TE,
-  RMS_BWD_TE,
-  RMS_FWD_CUDNN,
-  RMS_BWD_CUDNN,
-};
+enum class NormType { LayerNorm, RMSNorm };
+enum class NormStage { Forward, Backward };
+enum class NormImpl { TE, CUDNN };
```
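A hypothetical sketch of how the split enums could compose; the dispatch function is illustrative, not existing code:

```cpp
enum class NormType { LayerNorm, RMSNorm };
enum class NormStage { Forward, Backward };
enum class NormImpl { TE, CUDNN };

// e.g. RMS_BWD_CUDNN becomes (NormType::RMSNorm, NormStage::Backward,
// NormImpl::CUDNN), so adding a new implementation or stage no longer
// multiplies the enumerator count.
void dispatch_norm(NormType type, NormStage stage, NormImpl impl);
```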
```cpp
/* #include <transformer_engine/layer_norm.h> */
```
Suggested change (remove the commented-out include):

```diff
-/* #include <transformer_engine/layer_norm.h> */
```
```cpp
template <NVTE_NORM_TYPE NormEnum, bool IF_TUNED>
struct RegistryType<NormEnum, IF_TUNED, typename std::enable_if<IF_TE_NORMS<NormEnum>()>::type> {
  using type = std::conditional_t<
      IF_TUNED, std::conditional_t<IF_TE_FWD_NORMS<NormEnum>(), FwdTunedRegistry, BwdTunedRegistry>,
      std::conditional_t<IF_TE_FWD_NORMS<NormEnum>(), FwdGeneralRegistry, BwdGeneralRegistry>>;
};
```
While templating on `IF_TUNED` makes sense for our current kernel implementation, I don't think it's a good idea going forward. The reason TE compilation is so slow (and is unusable on some systems) is that we statically compile so many tuned LayerNorm/RMSNorm kernels. If we choose to keep the TE tuned kernels, we should port them to NVRTC so we can compile them at runtime, similar to how we handle the transpose kernels. This would completely remove the need for this static registry for tuned kernels (we'll still need it for the general kernels, though).
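For context, a minimal sketch of runtime compilation with NVRTC, under the assumption that a tuned kernel's source is available as a string; the names `kernel_src` and `ln_fwd_tuned` are illustrative and error handling is omitted:

```cpp
#include <cuda.h>
#include <nvrtc.h>
#include <string>
#include <vector>

CUfunction compile_tuned_kernel(const std::string& kernel_src) {
  // Compile the kernel source at runtime instead of statically building
  // every tuned LayerNorm/RMSNorm variant into the library.
  nvrtcProgram prog;
  nvrtcCreateProgram(&prog, kernel_src.c_str(), "ln_fwd_tuned.cu", 0, nullptr, nullptr);
  const char* opts[] = {"--std=c++17"};
  nvrtcCompileProgram(prog, 1, opts);

  // Retrieve the generated PTX and load it into the current CUDA context.
  size_t ptx_size;
  nvrtcGetPTXSize(prog, &ptx_size);
  std::vector<char> ptx(ptx_size);
  nvrtcGetPTX(prog, ptx.data());
  nvrtcDestroyProgram(&prog);

  CUmodule module;
  CUfunction kernel;
  cuModuleLoadData(&module, ptx.data());
  cuModuleGetFunction(&kernel, module, "ln_fwd_tuned");
  return kernel;
}
```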
```cpp
template <NVTE_NORM_TYPE NormEnum>
class NormFwdTe : public NormBase {
 public:
  NormFwdTe();
```
I see that calling this constructor is considered an error:

TransformerEngine/transformer_engine/common/layer_norm/norms.cpp, lines 147 to 152 in 41273eb:

```cpp
template <NVTE_NORM_TYPE NormEnum>
NormFwdTe<NormEnum>::NormFwdTe() {
  if constexpr (NormEnum == NVTE_NORM_TYPE::LN_FWD_TE) {
    NVTE_ERROR("NormFwdTe default constructor is only for its inherited classes!");
  }
}
```
We can catch this at compile-time with:

Suggested change:

```diff
-  NormFwdTe();
+  NormFwdTe() = delete;
```
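A small sketch (hypothetical simplified types) of the trade-off: the misuse that previously failed at runtime now fails to compile, though inherited classes then need another constructor to chain to:

```cpp
struct NormBase {};

template <int NormEnum>
class NormFwdTe : public NormBase {
 public:
  NormFwdTe() = delete;       // direct default construction is a compile error

 protected:
  explicit NormFwdTe(int) {}  // hypothetical hook for inherited classes
};

class DerivedFwd : public NormFwdTe<0> {
 public:
  DerivedFwd() : NormFwdTe<0>(0) {}  // chains to the protected constructor
};

// NormFwdTe<0> bad;  // error: use of deleted function
```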
```cpp
template <NVTE_NORM_TYPE NormEnum>
FwdFunction& get_fwd_launcher(DType wtype, DType itype, DType otype, DType ctype,
                              const FwdParams& params) {
  if constexpr (!IF_TE_FWD_NORMS<NormEnum>()) NVTE_ERROR("Unexpected NVTE_NORM_TYPE!");
```
We can catch these errors at compile-time with a `static_assert`:

Suggested change:

```diff
-  if constexpr (!IF_TE_FWD_NORMS<NormEnum>()) NVTE_ERROR("Unexpected NVTE_NORM_TYPE!");
+  static_assert(IF_TE_FWD_NORMS<NormEnum>(), "Unexpected NVTE_NORM_TYPE");
```

I see this `if constexpr (...) { NVTE_ERROR(...) }` pattern used in several other places.
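To illustrate the difference, a minimal standalone sketch (names are hypothetical): the `if constexpr` plus error-macro form compiles the bad instantiation and only fails when it runs, while `static_assert` on the template parameter rejects it during compilation:

```cpp
template <bool Supported>
void launch() {
  // Depends on the template parameter, so it only fires for the
  // instantiation launch<false>(), with the message below.
  static_assert(Supported, "Unexpected NVTE_NORM_TYPE");
}

int main() {
  launch<true>();      // compiles and runs
  // launch<false>();  // error: static assertion failed: Unexpected NVTE_NORM_TYPE
  return 0;
}
```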
Description

LayerNorm and RMSNorm refactor, preparing for the TE/cuDNN norms integration.