Develop Embedding[SiliconFlow] #72
Conversation
Why not register it in src/flag_gems/__init__.py?
BLOCK_SIZE = triton.next_power_of_2(N)
indices = indices.contiguous()
weight = weight.contiguous()
Is it necessary to make these tensors contiguous here first? It may cause memory-copy overhead.
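For what it's worth, `Tensor.contiguous()` only copies when the input is actually non-contiguous; for already-contiguous tensors it returns the same tensor, so the overhead is paid only on strided inputs. A quick check (not from the PR):

```python
import torch

# Contiguous input: .contiguous() returns the same tensor, no copy.
indices = torch.randint(0, 4096, (32,))
assert indices.contiguous() is indices

# Non-contiguous input (e.g. a strided view): a copy is made.
strided = torch.randint(0, 4096, (64,))[::2]
assert not strided.is_contiguous()
copied = strided.contiguous()
assert copied is not strided and copied.is_contiguous()
```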
Great job on such a complicated op! Just a few comments here. I think we can register a replacement for torch.nn.functional.embedding
in src/flag_gems/__init__.py
so we can use it with flag_gems.enable()
or with flag_gems.use_gems():
(a sketch of the registration follows below). I also wonder about the performance of this op and whether we should do some related tuning work, so would you add a benchmark for this op?
Also, there are some accuracy issues in the unit tests; pulling the latest master code may resolve them. It would also be better to add a PR description.
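A minimal sketch of what that registration might look like, assuming flag_gems patches aten ops through torch.library and that the Triton implementation is exposed as `embedding` (both the import path and the function name are assumptions, not taken from this PR):

```python
# Hypothetical sketch for src/flag_gems/__init__.py; adapt names to the
# actual codebase. Assumes flag_gems overrides aten ops via torch.library.
import torch
from .ops import embedding  # assumed import path for the Triton op

aten_lib = torch.library.Library("aten", "IMPL")

def enable(lib=aten_lib):
    # Route aten::embedding on CUDA to the Triton implementation.
    lib.impl("embedding", embedding, "CUDA")
```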
- I suggest implementing a performance test as well (see the sketch after this list).
- Accuracy tests failed with a relative difference equal to 1.0. What about rebasing on the latest master branch and trying again?
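One possible shape for such a performance test, sketched with triton.testing.do_bench and the sizes reported later in this thread (the actual flag_gems benchmark harness may look different):

```python
# Hypothetical standalone benchmark for the embedding forward pass.
# Sizes match the 4096 x 4096 table / 32-index case reported below.
import torch
import triton
import flag_gems

weight = torch.randn(4096, 4096, device="cuda")
indices = torch.randint(0, 4096, (32,), device="cuda")

torch_ms = triton.testing.do_bench(
    lambda: torch.nn.functional.embedding(indices, weight)
)
with flag_gems.use_gems():
    gems_ms = triton.testing.do_bench(
        lambda: torch.nn.functional.embedding(indices, weight)
    )
print(f"torch: {torch_ms:.4f} ms, flag_gems: {gems_ms:.4f} ms")
```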
We have verified locally that the unit-test accuracy problem can be solved by pulling the latest code.
LGTM
Test environment: RTX 4090
Test case: embedding table of 4096 × 4096, input index size 32
Forward:
Triton takes 2.82 µs, torch takes 3.78 µs
Backward:
Triton takes 14.78 µs, torch takes 3.78 µs
torch's backward pass uses a warp-ballot mechanism: one block reads all the indices into shared memory, and when duplicate indices are found the redundant warps are evicted so that only a single warp performs the accumulation for that index. Triton does not seem to offer such a mechanism.
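Since Triton exposes no warp-ballot primitive, a backward kernel typically resolves duplicate indices with atomic adds instead, which is one plausible source of the gap above. An illustrative sketch, not the PR's actual kernel (all names are made up for the example):

```python
import torch
import triton
import triton.language as tl

@triton.jit
def embedding_backward_kernel(
    grad_out_ptr, indices_ptr, grad_weight_ptr,
    N,  # embedding dimension
    BLOCK_SIZE: tl.constexpr,
):
    # One program per input index. Duplicate indices are serialized by
    # hardware atomics rather than deduplicated via warp ballot, so rows
    # that repeat an index pay for contended atomic adds.
    row = tl.program_id(0)
    idx = tl.load(indices_ptr + row)
    cols = tl.arange(0, BLOCK_SIZE)
    mask = cols < N
    grad = tl.load(grad_out_ptr + row * N + cols, mask=mask, other=0.0)
    tl.atomic_add(grad_weight_ptr + idx * N + cols, grad, mask=mask)

def embedding_backward(grad_out, indices, num_embeddings):
    # grad_out: (num_indices, N) upstream gradient, assumed contiguous fp32.
    M, N = grad_out.shape
    grad_weight = torch.zeros(
        num_embeddings, N, device=grad_out.device, dtype=grad_out.dtype
    )
    BLOCK_SIZE = triton.next_power_of_2(N)
    embedding_backward_kernel[(M,)](
        grad_out, indices, grad_weight, N, BLOCK_SIZE=BLOCK_SIZE
    )
    return grad_weight
```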