Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add xpu support #396

Open
wants to merge 8 commits into
base: main
Choose a base branch
from
Open

add xpu support #396

wants to merge 8 commits into from

Conversation

mgrabban
Copy link
Collaborator

Summary

Adds xpu support so all tests, benchmarks etc. run on XPUs or Intel GPUs.

Details

infer_device() function is moved to a separate file and in any file where previously "cuda" was needed, infer_device is imported and "cuda" is replaced with return value of a call to infer_device()

Testing Done

A100 80GB PCIe, RTX 3060, Intel Data Center GPU Max 1550

  • Hardware Type:
  • run make test to ensure correctness
  • run make checkstyle to ensure code style
  • run make test-convergence to ensure convergence

@lancerts lancerts requested review from ByronHsu, shimizust and lancerts and removed request for ByronHsu November 19, 2024 22:51
@faaany
Copy link
Collaborator

faaany commented Nov 20, 2024

I ran the UTs on XPU, but got "Segmentation fault (core dumped)" at one test, under investigation.

@mgrabban
Copy link
Collaborator Author

mgrabban commented Nov 20, 2024

I ran the UTs on XPU, but got "Segmentation fault (core dumped)" at one test, under investigation.

Which specific Intel GPU did you test on?
Also does the test run if you just change "cuda" to "xpu" manually without using this PR?

Also: I just added xpu support to simpo_loss (which was added later on and still had "cuda" hard coded).

@faaany
Copy link
Collaborator

faaany commented Nov 21, 2024

I ran the UTs on XPU, but got "Segmentation fault (core dumped)" at one test, under investigation.

Which specific Intel GPU did you test on? Also does the test run if you just change "cuda" to "xpu" manually without using this PR?

Also: I just added xpu support to simpo_loss (which was added later on and still had "cuda" hard coded).

I use "Intel Data Center GPU Max 1550".

And I tested your latest code. All tests pass except "pytest -rA test/transformers/test_rms_norm.py::test_correctness[True-BaseRMSNorm-0.0-none-dtype1-0.2-0.02-2-128-512]", but this one is a known issue and got fixed in the latest pytorch-triton-xpu. Don't you have this issue?

@mgrabban
Copy link
Collaborator Author

And I tested your latest code. All tests pass except "pytest -rA test/transformers/test_rms_norm.py::test_correctness[True-BaseRMSNorm-0.0-none-dtype1-0.2-0.02-2-128-512]", but this one is a known issue and got fixed in the latest pytorch-triton-xpu. Don't you have this issue?

I don't have this issue. It could be because I use nightly intel-xpu-backend-for-triton.

test_rms_norm.py::test_correctness[True-LlamaRMSNorm-0.0-llama-dtype0-0.0001-1e-06-2-128-512] PASSED                                                                           [  3%]
test_rms_norm.py::test_correctness[True-LlamaRMSNorm-0.0-llama-dtype0-0.0001-1e-06-5-123-123] PASSED                                                                           [  6%]
test_rms_norm.py::test_correctness[True-LlamaRMSNorm-0.0-llama-dtype1-0.2-0.02-2-128-512] PASSED                                                                               [  9%]
test_rms_norm.py::test_correctness[True-LlamaRMSNorm-0.0-llama-dtype1-0.2-0.02-5-123-123] PASSED                                                                               [ 12%]
test_rms_norm.py::test_correctness[True-GemmaRMSNorm-1.0-gemma-dtype0-0.0001-1e-06-2-128-512] PASSED                                                                           [ 15%]
test_rms_norm.py::test_correctness[True-GemmaRMSNorm-1.0-gemma-dtype0-0.0001-1e-06-5-123-123] PASSED                                                                           [ 18%]
test_rms_norm.py::test_correctness[True-GemmaRMSNorm-1.0-gemma-dtype1-0.2-0.02-2-128-512] PASSED                                                                               [ 21%]
test_rms_norm.py::test_correctness[True-GemmaRMSNorm-1.0-gemma-dtype1-0.2-0.02-5-123-123] PASSED                                                                               [ 25%]
test_rms_norm.py::test_correctness[True-BaseRMSNorm-0.0-none-dtype0-0.0001-1e-06-2-128-512] PASSED                                                                             [ 28%]
test_rms_norm.py::test_correctness[True-BaseRMSNorm-0.0-none-dtype0-0.0001-1e-06-5-123-123] PASSED                                                                             [ 31%]
test_rms_norm.py::test_correctness[True-BaseRMSNorm-0.0-none-dtype1-0.2-0.02-2-128-512] PASSED                                                                                 [ 34%]
test_rms_norm.py::test_correctness[True-BaseRMSNorm-0.0-none-dtype1-0.2-0.02-5-123-123] PASSED                                                                                 [ 37%]
test_rms_norm.py::test_correctness[False-LlamaRMSNorm-0.0-llama-dtype0-0.0001-1e-06-2-128-512] PASSED                                                                          [ 40%]
test_rms_norm.py::test_correctness[False-LlamaRMSNorm-0.0-llama-dtype0-0.0001-1e-06-5-123-123] PASSED                                                                          [ 43%]
test_rms_norm.py::test_correctness[False-LlamaRMSNorm-0.0-llama-dtype1-0.2-0.02-2-128-512] PASSED                                                                              [ 46%]
test_rms_norm.py::test_correctness[False-LlamaRMSNorm-0.0-llama-dtype1-0.2-0.02-5-123-123] PASSED                                                                              [ 50%]
test_rms_norm.py::test_correctness[False-GemmaRMSNorm-1.0-gemma-dtype0-0.0001-1e-06-2-128-512] PASSED                                                                          [ 53%]
test_rms_norm.py::test_correctness[False-GemmaRMSNorm-1.0-gemma-dtype0-0.0001-1e-06-5-123-123] PASSED                                                                          [ 56%]
test_rms_norm.py::test_correctness[False-GemmaRMSNorm-1.0-gemma-dtype1-0.2-0.02-2-128-512] PASSED                                                                              [ 59%]
test_rms_norm.py::test_correctness[False-GemmaRMSNorm-1.0-gemma-dtype1-0.2-0.02-5-123-123] PASSED                                                                              [ 62%]
test_rms_norm.py::test_correctness[False-BaseRMSNorm-0.0-none-dtype0-0.0001-1e-06-2-128-512] PASSED                                                                            [ 65%]
test_rms_norm.py::test_correctness[False-BaseRMSNorm-0.0-none-dtype0-0.0001-1e-06-5-123-123] PASSED                                                                            [ 68%]
test_rms_norm.py::test_correctness[False-BaseRMSNorm-0.0-none-dtype1-0.2-0.02-2-128-512] PASSED                                                                                [ 71%]
test_rms_norm.py::test_correctness[False-BaseRMSNorm-0.0-none-dtype1-0.2-0.02-5-123-123] PASSED                                                                                [ 75%]
test_rms_norm.py::test_correctness_functional[LlamaRMSNorm-0.0-llama-dtype0-0.0001-1e-06-2-2-8] PASSED                                                                         [ 78%]
test_rms_norm.py::test_correctness_functional[LlamaRMSNorm-0.0-llama-dtype0-0.0001-1e-06-9-7-41] PASSED                                                                        [ 81%]
test_rms_norm.py::test_correctness_functional[LlamaRMSNorm-0.0-llama-dtype1-0.2-0.02-2-2-8] PASSED                                                                             [ 84%]
test_rms_norm.py::test_correctness_functional[LlamaRMSNorm-0.0-llama-dtype1-0.2-0.02-9-7-41] PASSED                                                                            [ 87%]
test_rms_norm.py::test_correctness_functional[GemmaRMSNorm-1.0-gemma-dtype0-0.0001-1e-06-2-2-8] PASSED                                                                         [ 90%]
test_rms_norm.py::test_correctness_functional[GemmaRMSNorm-1.0-gemma-dtype0-0.0001-1e-06-9-7-41] PASSED                                                                        [ 93%]
test_rms_norm.py::test_correctness_functional[GemmaRMSNorm-1.0-gemma-dtype1-0.2-0.02-2-2-8] PASSED                                                                             [ 96%]
test_rms_norm.py::test_correctness_functional[GemmaRMSNorm-1.0-gemma-dtype1-0.2-0.02-9-7-41] PASSED                                                                            [100%]

@faaany
Copy link
Collaborator

faaany commented Nov 22, 2024

And I tested your latest code. All tests pass except "pytest -rA test/transformers/test_rms_norm.py::test_correctness[True-BaseRMSNorm-0.0-none-dtype1-0.2-0.02-2-128-512]", but this one is a known issue and got fixed in the latest pytorch-triton-xpu. Don't you have this issue?

I don't have this issue. It could be because I use nightly intel-xpu-backend-for-triton.

test_rms_norm.py::test_correctness[True-LlamaRMSNorm-0.0-llama-dtype0-0.0001-1e-06-2-128-512] PASSED                                                                           [  3%]
test_rms_norm.py::test_correctness[True-LlamaRMSNorm-0.0-llama-dtype0-0.0001-1e-06-5-123-123] PASSED                                                                           [  6%]
test_rms_norm.py::test_correctness[True-LlamaRMSNorm-0.0-llama-dtype1-0.2-0.02-2-128-512] PASSED                                                                               [  9%]
test_rms_norm.py::test_correctness[True-LlamaRMSNorm-0.0-llama-dtype1-0.2-0.02-5-123-123] PASSED                                                                               [ 12%]
test_rms_norm.py::test_correctness[True-GemmaRMSNorm-1.0-gemma-dtype0-0.0001-1e-06-2-128-512] PASSED                                                                           [ 15%]
test_rms_norm.py::test_correctness[True-GemmaRMSNorm-1.0-gemma-dtype0-0.0001-1e-06-5-123-123] PASSED                                                                           [ 18%]
test_rms_norm.py::test_correctness[True-GemmaRMSNorm-1.0-gemma-dtype1-0.2-0.02-2-128-512] PASSED                                                                               [ 21%]
test_rms_norm.py::test_correctness[True-GemmaRMSNorm-1.0-gemma-dtype1-0.2-0.02-5-123-123] PASSED                                                                               [ 25%]
test_rms_norm.py::test_correctness[True-BaseRMSNorm-0.0-none-dtype0-0.0001-1e-06-2-128-512] PASSED                                                                             [ 28%]
test_rms_norm.py::test_correctness[True-BaseRMSNorm-0.0-none-dtype0-0.0001-1e-06-5-123-123] PASSED                                                                             [ 31%]
test_rms_norm.py::test_correctness[True-BaseRMSNorm-0.0-none-dtype1-0.2-0.02-2-128-512] PASSED                                                                                 [ 34%]
test_rms_norm.py::test_correctness[True-BaseRMSNorm-0.0-none-dtype1-0.2-0.02-5-123-123] PASSED                                                                                 [ 37%]
test_rms_norm.py::test_correctness[False-LlamaRMSNorm-0.0-llama-dtype0-0.0001-1e-06-2-128-512] PASSED                                                                          [ 40%]
test_rms_norm.py::test_correctness[False-LlamaRMSNorm-0.0-llama-dtype0-0.0001-1e-06-5-123-123] PASSED                                                                          [ 43%]
test_rms_norm.py::test_correctness[False-LlamaRMSNorm-0.0-llama-dtype1-0.2-0.02-2-128-512] PASSED                                                                              [ 46%]
test_rms_norm.py::test_correctness[False-LlamaRMSNorm-0.0-llama-dtype1-0.2-0.02-5-123-123] PASSED                                                                              [ 50%]
test_rms_norm.py::test_correctness[False-GemmaRMSNorm-1.0-gemma-dtype0-0.0001-1e-06-2-128-512] PASSED                                                                          [ 53%]
test_rms_norm.py::test_correctness[False-GemmaRMSNorm-1.0-gemma-dtype0-0.0001-1e-06-5-123-123] PASSED                                                                          [ 56%]
test_rms_norm.py::test_correctness[False-GemmaRMSNorm-1.0-gemma-dtype1-0.2-0.02-2-128-512] PASSED                                                                              [ 59%]
test_rms_norm.py::test_correctness[False-GemmaRMSNorm-1.0-gemma-dtype1-0.2-0.02-5-123-123] PASSED                                                                              [ 62%]
test_rms_norm.py::test_correctness[False-BaseRMSNorm-0.0-none-dtype0-0.0001-1e-06-2-128-512] PASSED                                                                            [ 65%]
test_rms_norm.py::test_correctness[False-BaseRMSNorm-0.0-none-dtype0-0.0001-1e-06-5-123-123] PASSED                                                                            [ 68%]
test_rms_norm.py::test_correctness[False-BaseRMSNorm-0.0-none-dtype1-0.2-0.02-2-128-512] PASSED                                                                                [ 71%]
test_rms_norm.py::test_correctness[False-BaseRMSNorm-0.0-none-dtype1-0.2-0.02-5-123-123] PASSED                                                                                [ 75%]
test_rms_norm.py::test_correctness_functional[LlamaRMSNorm-0.0-llama-dtype0-0.0001-1e-06-2-2-8] PASSED                                                                         [ 78%]
test_rms_norm.py::test_correctness_functional[LlamaRMSNorm-0.0-llama-dtype0-0.0001-1e-06-9-7-41] PASSED                                                                        [ 81%]
test_rms_norm.py::test_correctness_functional[LlamaRMSNorm-0.0-llama-dtype1-0.2-0.02-2-2-8] PASSED                                                                             [ 84%]
test_rms_norm.py::test_correctness_functional[LlamaRMSNorm-0.0-llama-dtype1-0.2-0.02-9-7-41] PASSED                                                                            [ 87%]
test_rms_norm.py::test_correctness_functional[GemmaRMSNorm-1.0-gemma-dtype0-0.0001-1e-06-2-2-8] PASSED                                                                         [ 90%]
test_rms_norm.py::test_correctness_functional[GemmaRMSNorm-1.0-gemma-dtype0-0.0001-1e-06-9-7-41] PASSED                                                                        [ 93%]
test_rms_norm.py::test_correctness_functional[GemmaRMSNorm-1.0-gemma-dtype1-0.2-0.02-2-2-8] PASSED                                                                             [ 96%]
test_rms_norm.py::test_correctness_functional[GemmaRMSNorm-1.0-gemma-dtype1-0.2-0.02-9-7-41] PASSED                                                                            [100%]

Thanks for the update. This PR looks good to me.

@ByronHsu
Copy link
Collaborator

@faaany @mgrabban can you fix the conflict and we can merge this ASAP?

@ByronHsu
Copy link
Collaborator

Looks good to me! @mgrabban i just invited you as the collab of this repo, can you check the email? After acceptance, can you create a new branch in the main repo, and create a new PR based on that branch? Our CI has issues currently, so any PR from external folks cannot run CI. Thanks in advance!!

@mgrabban mgrabban mentioned this pull request Nov 22, 2024
3 tasks
@mgrabban
Copy link
Collaborator Author

Looks good to me! @mgrabban i just invited you as the collab of this repo, can you check the email? After acceptance, can you create a new branch in the main repo, and create a new PR based on that branch? Our CI has issues currently, so any PR from external folks cannot run CI. Thanks in advance!!

This is done now. See #407

ByronHsu pushed a commit that referenced this pull request Nov 23, 2024
## Summary
Replica of #396 
Adds xpu support so all tests, benchmarks etc. run on XPUs or Intel
GPUs.

## Details
infer_device() function is moved to a separate file and in any file
where previously "cuda" was needed, infer_device is imported and "cuda"
is replaced with return value of a call to infer_device()

## Testing Done
<!--- This is a required section; please describe how this change was
tested. --->
A100 80GB PCIe, RTX 3060, Intel Data Center GPU Max 1550
<!-- 
Replace BLANK with your device type. For example, A100-80G-PCIe

Complete the following tasks before sending your PR, and replace `[ ]`
with
`[x]` to indicate you have done them. 
-->

- Hardware Type: <BLANK>
- [x] run `make test` to ensure correctness
- [ ] run `make checkstyle` to ensure code style
- [x] run `make test-convergence` to ensure convergence

---------

Co-authored-by: Shao Tang <[email protected]>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants