refactor: RMSNorm #59
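A few points about Python modules:
- If a folder is used as a module, it must contain an __init__.py.
- If there are not multiple files, no folder is needed; use a single file as the module.
- In every case, the entry point that gets imported must define __all__.
In addition, the long-term plan should be for the cpp layer to interface with diopi and the Python layer to interface with the Python layer, so I suggest renaming common to ops.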
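To illustrate the points about __init__.py and __all__, here is a minimal sketch. It builds a throwaway package on disk and checks which names a star import exposes. The package name demo_ops, the file impl.py, and the placeholder functions are hypothetical and are not taken from this PR.

```python
import importlib
import sys
import tempfile
from pathlib import Path

# Hypothetical package name, used only for this illustration.
PKG_NAME = "demo_ops"


def build_demo_package(root: Path) -> None:
    """Create a folder-based package with an __init__.py that defines __all__."""
    pkg = root / PKG_NAME
    pkg.mkdir()
    # Implementation file with one public and one private placeholder function.
    (pkg / "impl.py").write_text(
        "def rms_norm(x):\n    return x\n\n"
        "def _helper(x):\n    return x\n"
    )
    # The import entry point: re-export names and declare the public API.
    (pkg / "__init__.py").write_text(
        "import os\n"
        "from .impl import rms_norm, _helper\n"
        '__all__ = ["rms_norm"]\n'
    )


def star_import_names(root: Path) -> set:
    """Run `from demo_ops import *` in a fresh namespace and return the names it binds."""
    sys.path.insert(0, str(root))
    try:
        importlib.invalidate_caches()
        namespace = {}
        exec(f"from {PKG_NAME} import *", namespace)
        # Drop __builtins__, which exec adds to the namespace on its own.
        return {name for name in namespace if not name.startswith("__")}
    finally:
        sys.path.remove(str(root))


with tempfile.TemporaryDirectory() as tmp:
    root = Path(tmp)
    build_demo_package(root)
    exported = star_import_names(root)
    print(exported)  # {'rms_norm'}
```

Without the __all__ line, the same star import would also bind os and the impl submodule, which is why an explicit list on the import entry point is useful.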
from .deeplink import rms_norm_out, rms_norm, rms_norm_backward_out, rms_norm_backward

all = ["rms_norm_out", "rms_norm", "rms_norm_backward_out", "rms_norm_backward"]
- all = ["rms_norm_out", "rms_norm", "rms_norm_backward_out", "rms_norm_backward"]
+ __all__ = ["rms_norm_out", "rms_norm", "rms_norm_backward_out", "rms_norm_backward"]
auto extRmsNorm(const at::Tensor& input,
auto extRmsNorm(at::Tensor& output, at::Tensor& inv_rms,
                const at::Tensor& input,
                const OptionalIntArray& normalized_shape,
This signature needs to be changed.
That header file can be deleted entirely.
)

__all__ = ["mha", "rotary", RMSNorm, RMSNormWithNormalizedShape]
- __all__ = ["mha", "rotary", RMSNorm, RMSNormWithNormalizedShape]
+ __all__ = ["mha", "rotary", "RMSNorm", "RMSNormWithNormalizedShape"]
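A small sketch of why the quoted form matters, assuming the class names are already imported in that file so the unquoted list holds class objects rather than strings. The module name demo_layers and the placeholder class are hypothetical stand-ins, not the project's code.

```python
import sys
import types

# Hypothetical stand-in module registered directly in sys.modules.
mod = types.ModuleType("demo_layers")


class RMSNorm:  # placeholder class, not the project's implementation
    pass


mod.RMSNorm = RMSNorm
mod.mha = object()
sys.modules["demo_layers"] = mod


def try_star_import() -> str:
    """Attempt `from demo_layers import *` and report what happened."""
    namespace = {}
    try:
        exec("from demo_layers import *", namespace)
    except TypeError as err:
        return f"TypeError: {err}"
    names = sorted(n for n in namespace if not n.startswith("__"))
    return "ok: " + ", ".join(names)


# Unquoted entry: the list contains a class object, so the star import fails.
mod.__all__ = ["mha", RMSNorm]
unquoted_result = try_star_import()

# Quoted entries: every item is a string naming a module attribute.
mod.__all__ = ["mha", "RMSNorm"]
quoted_result = try_star_import()

print(unquoted_result)
print(quoted_result)  # ok: RMSNorm, mha
```

A plain `import demo_layers` never reads __all__, so the unquoted list only breaks once someone uses a star import, which makes the mistake easy to miss in review.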
Wouldn't ops be a better name than common?
Let's revisit this later; common still feels more intuitive to me.
Refactor the RMS norm op, rotary embedding, and mha.