
Develop Embedding[SiliconFlow] #72

Merged
merged 13 commits into master on Jul 18, 2024
Conversation

MARD1NO
Collaborator

@MARD1NO MARD1NO commented Jun 17, 2024

Test environment: RTX 4090

Test case: embedding table of size 4096 × 4096, input index size 32.

Forward:
Triton takes 2.82 us, Torch takes 3.78 us.

Backward:
Triton takes 14.78 us, Torch takes 3.78 us.

Torch's backward pass uses a warp ballot: one block reads all the indices into shared memory, and when duplicate indices are encountered, the redundant warps are evicted so that only one warp performs the accumulation for that index. Triton does not seem to offer such a mechanism.
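For reference, the semantics that the warp-ballot trick must preserve are simple: gradient rows for duplicate indices have to accumulate into the same row of the weight gradient. A minimal pure-Python sketch of that accumulation (function and variable names are illustrative, not taken from the kernel):

```python
def embedding_backward(grad_out, indices, num_embeddings):
    """Reference embedding backward: each row of grad_out is scatter-added
    into grad_weight at the position given by the matching index. Duplicate
    indices must accumulate, which is what the warp-ballot dedup optimizes
    (and what a naive parallel kernel handles with atomic adds)."""
    dim = len(grad_out[0])
    grad_weight = [[0.0] * dim for _ in range(num_embeddings)]
    for row, idx in zip(grad_out, indices):
        for j, g in enumerate(row):
            grad_weight[idx][j] += g  # would be an atomic add in parallel
    return grad_weight

# index 1 appears twice, so its two gradient rows sum
grad = embedding_backward([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]], [1, 1, 0], 3)
# grad[1] == [4.0, 6.0], grad[0] == [5.0, 6.0], grad[2] == [0.0, 0.0]
```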

@MARD1NO MARD1NO marked this pull request as ready for review July 2, 2024 07:10
@MARD1NO MARD1NO requested a review from StrongSpoon July 3, 2024 07:31
@StrongSpoon
Collaborator

Why not register it in src/flag_gems/__init__.py?


BLOCK_SIZE = triton.next_power_of_2(N)
indices = indices.contiguous()
weight = weight.contiguous()
Collaborator

@Bowen12992 Bowen12992 Jul 9, 2024


Is it necessary to make these contiguous first? It may incur memory-copy overhead.
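For what it's worth, PyTorch's .contiguous() (like NumPy's ascontiguousarray) only copies when the input is actually non-contiguous, so the overhead is avoided in the common case. A small NumPy illustration of that behavior:

```python
import numpy as np

a = np.arange(12, dtype=np.float32).reshape(3, 4)  # already C-contiguous
same = np.ascontiguousarray(a)    # no copy: the same array comes back
copied = np.ascontiguousarray(a.T)  # a transpose is not C-contiguous,
                                    # so this one does copy

print(same is a)                     # True
print(a.T.flags['C_CONTIGUOUS'])     # False
print(copied.flags['C_CONTIGUOUS'])  # True
```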

Collaborator

@Bowen12992 Bowen12992 left a comment


Great job on such a complicated op! Just a few comments here. I think we can replace torch.nn.functional.embedding in src/flag_gems/__init__.py so the op can be used with flag_gems.enable() or with flag_gems.use_gems():. I also wonder about this op's performance and whether we should do some related tuning work, so would you add a benchmark for it?
Also, there are some accuracy issues in the unit tests; pulling the master code may resolve them. And please add a PR description.

Copy link
Collaborator

@StrongSpoon StrongSpoon left a comment


  1. I suggest implementing a performance test as well.
  2. Accuracy tests failed with a relative difference equal to 1.0. What about rebasing on the latest master branch and trying again?

@Bowen12992
Collaborator

  1. I suggest implementing a performance test as well.
  2. Accuracy tests failed with a relative difference equal to 1.0. What about rebasing on the latest master branch and trying again?

We have verified locally that the unit-test accuracy problem can be solved by pulling the latest code.

Collaborator

@Bowen12992 Bowen12992 left a comment


LGTM

@Bowen12992 Bowen12992 merged commit 7156a8f into master Jul 18, 2024
3 checks passed
@StrongSpoon StrongSpoon deleted the dev_embedding branch August 13, 2024 07:37