Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor: RMSNorm #59

Merged
merged 43 commits into from
Apr 1, 2024
Merged

refactor: RMSNorm #59

merged 43 commits into from
Apr 1, 2024

Conversation

zhangzefeng92
Copy link
Contributor

@zhangzefeng92 zhangzefeng92 commented Mar 26, 2024

refactor rms norm op, and rotary_embeding and mha.

@zhangzefeng92 zhangzefeng92 changed the title zzf/update cpp ext of rms norm refactor rms norm -zf Mar 27, 2024
@zhangzefeng92 zhangzefeng92 changed the title refactor rms norm -zf refactor-rms_norm Mar 27, 2024
@zhangzefeng92 zhangzefeng92 changed the title refactor-rms_norm refactor:rms_norm Mar 27, 2024
@zhangzefeng92 zhangzefeng92 changed the title refactor:rms_norm refactor:RMSNorm Mar 27, 2024
@zhangzefeng92 zhangzefeng92 changed the title refactor:RMSNorm refactor: RMSNorm Mar 27, 2024
@zhangzefeng92 zhangzefeng92 force-pushed the zzf/fix_rmsnorm branch 6 times, most recently from 8eca230 to 7c55aa9 Compare March 28, 2024 08:07
@lljbash lljbash requested a review from POI-WX March 29, 2024 02:42
Copy link
Collaborator

@lljbash lljbash left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

有几点关于 python module 的东西:

  1. 如果文件夹当 module,那么必须要有 __init__.py
  2. 如果不是多个文件,那么不需要文件夹,直接用文件当 module
  3. 无论什么情况,被 import 的入口都要有 __all__

此外,长期规划应该是 cpp 对接 diopi,python 层对接 python 层,那么我建议把 common 改叫 ops

from .deeplink import rms_norm_out, rms_norm, rms_norm_backward_out, rms_norm_backward


all = ["rms_norm_out", "rms_norm", "rms_norm_backward_out", "rms_norm_backward"]
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
all = ["rms_norm_out", "rms_norm", "rms_norm_backward_out", "rms_norm_backward"]
__all__ = ["rms_norm_out", "rms_norm", "rms_norm_backward_out", "rms_norm_backward"]


auto extRmsNorm(const at::Tensor& input,
auto extRmsNorm(at::Tensor& output, at::Tensor& inv_rms,
const at::Tensor& input,
const OptionalIntArray& normalized_shape,
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这个签名需要修改

那个头文件可以完全删除

)


__all__ = ["mha", "rotary", RMSNorm, RMSNormWithNormalizedShape]
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
__all__ = ["mha", "rotary", RMSNorm, RMSNormWithNormalizedShape]
__all__ = ["mha", "rotary", "RMSNorm", "RMSNormWithNormalizedShape"]

@lljbash
Copy link
Collaborator

lljbash commented Apr 1, 2024

common 叫 ops 比较好吧

@yangbofun
Copy link
Collaborator

common 叫 ops 比较好吧

这个等后续看看,感觉common意思还更直观。

@yangbofun yangbofun merged commit 8261278 into main Apr 1, 2024
5 checks passed
@yangbofun yangbofun deleted the zzf/fix_rmsnorm branch April 1, 2024 08:47
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants