
Norms Refactor #1140

Draft · phu0ngng wants to merge 1 commit into main
Conversation

phu0ngng (Collaborator)

Description

LayerNorm and RMSNorm refactor, preparing for the TE/cuDNN norms integration.

Type of change

  • Documentation change (change only to the documentation, either a fix or new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactor

Changes

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

@phu0ngng phu0ngng requested a review from ptrendx August 27, 2024 20:59
Signed-off-by: Phuong Nguyen <[email protected]>
@phu0ngng (Collaborator, Author)

/te-ci pytorch

Comment on lines +7 to +8
#ifndef TRANSFORMER_ENGINE_COMMON_LAYER_NORM_LN_H_
#define TRANSFORMER_ENGINE_COMMON_LAYER_NORM_LN_H_
Collaborator

Suggested change
#ifndef TRANSFORMER_ENGINE_COMMON_LAYER_NORM_LN_H_
#define TRANSFORMER_ENGINE_COMMON_LAYER_NORM_LN_H_
#ifndef TRANSFORMER_ENGINE_COMMON_LAYER_NORM_NORMS_H_
#define TRANSFORMER_ENGINE_COMMON_LAYER_NORM_NORMS_H_


#include "../common.h"

namespace transformer_engine {
Collaborator

I think it would be cleaner if we wrapped these objects in a namespace:

Suggested change
namespace transformer_engine {
namespace transformer_engine {
namespace norms {

The names of these classes (LaunchParams, FwdParams, etc.) are not very specific, so putting them in a namespace makes it clear that they are related to LayerNorm/RMSNorm.
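
For illustration only, a minimal sketch of what the nested namespace might look like; the FwdParams members shown here are placeholders, not the PR's actual struct layout:

namespace transformer_engine {
namespace norms {

// Short, generic names now live under transformer_engine::norms,
// so they do not collide with unrelated parameter types elsewhere.
struct FwdParams {
  void* x = nullptr;  // hypothetical members, for illustration only
  void* z = nullptr;
};

}  // namespace norms
}  // namespace transformer_engine

// A call site outside the namespace qualifies the name explicitly:
// transformer_engine::norms::FwdParams params;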

z->data.dtype, // otype,
DType::kFloat32, // ctype,
params);
if (params.fp8_out) set_amax();
Collaborator

This zeroing might be unnecessary. The FP8 scale update kernel called by PyTorch also zeros out the amax.

That said, this matches the existing behavior, so there is no need to change it right now.


// Scaling factor
void* scale;
int scale_byte_size;
Collaborator

Are we using this value?

Suggested change
int scale_byte_size;


// AMax output
void* amax;
int amax_byte_size;
Collaborator

If we remove the amax zeroing kernel, I don't think we are using this value:

Suggested change
int amax_byte_size;

If we want to keep an option to zero out the amax, then could we just assume the amax is FP32 like we do in the kernels?
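
As a rough sketch of that simplification, assuming the amax buffer is a single FP32 value and a CUDA stream is at hand (the helper name is made up for this example):

#include <cuda_runtime.h>

// Zero the amax without tracking a separate byte size,
// assuming it is one float, as the kernels already do.
inline void zero_amax(void* amax, cudaStream_t stream) {
  if (amax != nullptr) {
    cudaMemsetAsync(amax, 0, sizeof(float), stream);
  }
}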

Comment on lines +118 to +127
enum NVTE_NORM_TYPE {
  LN_FWD_TE,
  LN_BWD_TE,
  LN_FWD_CUDNN,
  LN_BWD_CUDNN,
  RMS_FWD_TE,
  RMS_BWD_TE,
  RMS_FWD_CUDNN,
  RMS_BWD_CUDNN,
};
Collaborator

Logically this can be split into three enums:

Suggested change
enum NVTE_NORM_TYPE {
  LN_FWD_TE,
  LN_BWD_TE,
  LN_FWD_CUDNN,
  LN_BWD_CUDNN,
  RMS_FWD_TE,
  RMS_BWD_TE,
  RMS_FWD_CUDNN,
  RMS_BWD_CUDNN,
};
enum class NormType { LayerNorm, RMSNorm };
enum class NormStage { Forward, Backward };
enum class NormImpl { TE, CUDNN };
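
A minimal sketch of how a dispatcher could consume the three enums instead of one flat value; the function name and branch bodies are hypothetical:

enum class NormType { LayerNorm, RMSNorm };
enum class NormStage { Forward, Backward };
enum class NormImpl { TE, CUDNN };

// Hypothetical dispatcher: each axis is selected independently, so adding
// a new implementation does not multiply the number of enum values.
void launch_norm(NormType type, NormStage stage, NormImpl impl) {
  if (type == NormType::LayerNorm && stage == NormStage::Forward && impl == NormImpl::TE) {
    // launch the TE LayerNorm forward kernel
  } else if (type == NormType::RMSNorm && stage == NormStage::Backward && impl == NormImpl::CUDNN) {
    // launch the cuDNN RMSNorm backward kernel
  }
  // ... remaining combinations
}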

Comment on lines +6 to +7

/* #include <transformer_engine/layer_norm.h> */
Collaborator

Suggested change
/* #include <transformer_engine/layer_norm.h> */

Comment on lines +176 to +181
template <NVTE_NORM_TYPE NormEnum, bool IF_TUNED>
struct RegistryType<NormEnum, IF_TUNED, typename std::enable_if<IF_TE_NORMS<NormEnum>()>::type> {
  using type = std::conditional_t<
      IF_TUNED, std::conditional_t<IF_TE_FWD_NORMS<NormEnum>(), FwdTunedRegistry, BwdTunedRegistry>,
      std::conditional_t<IF_TE_FWD_NORMS<NormEnum>(), FwdGeneralRegistry, BwdGeneralRegistry>>;
};
@timmoon10 (Collaborator), Aug 30, 2024

While templating on IF_TUNED makes sense for our current kernel implementation, I don't think it's a good idea going forward. The reason TE compilation is so slow (and is unusable on some systems) is that we statically compile so many tuned LayerNorm/RMSNorm kernels. If we choose to keep the TE tuned kernels, we should port them to NVRTC so we can compile them at runtime, similar to how we handle the transpose kernels. This would completely remove the need for this static registry for tuned kernels (we will still need it for general kernels, though).
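
For context, a rough sketch of runtime compilation with NVRTC, assuming a kernel source string is generated per configuration; this is only an outline with error handling omitted, and it is not necessarily how the transpose kernels are wired up:

#include <nvrtc.h>

#include <string>
#include <vector>

// Compile a generated kernel source string to PTX at runtime instead of
// statically building every tuned variant ahead of time.
std::string compile_to_ptx(const std::string& src, const std::vector<const char*>& opts) {
  nvrtcProgram prog;
  nvrtcCreateProgram(&prog, src.c_str(), "norm_kernel.cu", 0, nullptr, nullptr);
  nvrtcCompileProgram(prog, static_cast<int>(opts.size()), opts.data());
  size_t ptx_size = 0;
  nvrtcGetPTXSize(prog, &ptx_size);
  std::string ptx(ptx_size, '\0');
  nvrtcGetPTX(prog, &ptx[0]);
  nvrtcDestroyProgram(&prog);
  return ptx;
}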

template <NVTE_NORM_TYPE NormEnum>
class NormFwdTe : public NormBase {
 public:
  NormFwdTe();
Collaborator

I see that calling this constructor is considered an error:

template <NVTE_NORM_TYPE NormEnum>
NormFwdTe<NormEnum>::NormFwdTe() {
  if constexpr (NormEnum == NVTE_NORM_TYPE::LN_FWD_TE) {
    NVTE_ERROR("NormFwdTe default constructor is only for its inherited classes!");
  }
}

We can catch this at compile-time with:

Suggested change
NormFwdTe();
NormFwdTe() = delete;
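
As a small illustration of the difference (stand-in class names, not the PR's actual hierarchy), a deleted default constructor turns the misuse into a compile error, while derived classes can still delegate to a different base constructor:

struct Base {
  Base() = delete;           // direct default construction no longer compiles
  explicit Base(int tag) {}  // derived classes delegate to this instead
};

struct Derived : Base {
  Derived() : Base(0) {}     // OK: picks the non-deleted constructor
};

// Base b;     // error: use of deleted function 'Base::Base()'
// Derived d;  // OK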

template <NVTE_NORM_TYPE NormEnum>
FwdFunction& get_fwd_launcher(DType wtype, DType itype, DType otype, DType ctype,
                              const FwdParams& params) {
  if constexpr (!IF_TE_FWD_NORMS<NormEnum>()) NVTE_ERROR("Unexpected NVTE_NORM_TYPE!");
Collaborator

We can catch these errors at compile-time with a static_assert:

Suggested change
if constexpr (!IF_TE_FWD_NORMS<NormEnum>()) NVTE_ERROR("Unexpected NVTE_NORM_TYPE!");
static_assert(IF_TE_FWD_NORMS<NormEnum>(), "Unexpected NVTE_NORM_TYPE");

I see this if constexpr (...) { NVTE_ERROR(...) } pattern used in several other places.

@timmoon10 timmoon10 self-requested a review August 30, 2024 22:29
@phu0ngng phu0ngng marked this pull request as draft September 4, 2024 23:35