
Reduce runtime dependency on torch #205

Open · wants to merge 1 commit into main
Conversation

@stephen-huan (Author)

Fixes #204. It is now possible to import triton, execute a kernel, and autotune, all without torch.

Relatively straightforward translation of torch.tensor -> np.array, torch.float -> np.float32, torch.quantile -> np.quantile (note the default interpolation method for both is linear), and so on. Luckily numpy has np.min, np.max, np.mean, np.median, and np.all with the same semantics as the respective torch functions, so the same getattr(np, return_mode)(times) trick can be used.
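For illustration, a minimal sketch of the translated aggregation logic (the timing values here are made up; times and return_mode mirror the names used in the autotuner):

```python
import numpy as np

# Timing samples collected during benchmarking (illustrative values).
times = np.array([1.2, 0.9, 1.1, 1.0], dtype=np.float32)

# torch.quantile -> np.quantile: both default to the "linear" method.
quantiles = np.quantile(times, np.array([0.5, 0.2, 0.8], dtype=np.float32))

# np.min, np.max, np.mean, np.median, and np.all match the torch semantics,
# so the dispatch-by-name trick carries over unchanged.
return_mode = "median"
result = getattr(np, return_mode)(times)
```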

I was getting

ImportError: ~/.triton/cache/TMED6KFKLZGWOEVFFCXVTCPBLE/__triton_cpu_launcher.so: undefined symbol: omp_get_thread_num

which was fixed either by importing torch or by adding "omp" to the libraries list in the CPU driver,

libraries = ["stdc++"]

but I can't seem to reproduce this anymore.
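For reference, the abandoned workaround in the CPU driver would have looked something like this (a sketch, not a committed change):

```python
# Sketch of the workaround: link the launcher against OpenMP so
# omp_get_thread_num resolves without torch pre-loading its own OpenMP runtime.
libraries = ["stdc++", "omp"]
```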

New contributor declaration

  • I am not making a trivial change, such as fixing a typo in a comment.

  • I have written a PR description following these rules.

  • I have run pre-commit run --from-ref origin/main --to-ref HEAD.

  • Select one of the following.

    • I have added tests.
      • /test for lit tests
      • /unittest for C++ tests
      • /python/test for end-to-end tests
    • This PR does not need a test because it is a runtime-only change.
  • Select one of the following.

    • I have not added any lit tests.
    • The lit tests I have added follow these best practices,
      including the "tests should be minimal" section. (Usually running Python code
      and using the instructions it generates is not minimal.)

@minjang (Collaborator) left a comment


Thanks for the suggestion! I think it looks good.

As you know, TritonCPU is forked from Triton, and we periodically rebase onto the latest Triton; we always see many merge conflicts. While this PR is a small change, it's still highly likely to cause future merge conflicts.

So, can you also make a PR in the upstream?

@ienkovich (Collaborator)

I agree that this change is not specific to the Triton CPU backend and should go through the original repo.

@stephen-huan (Author)

So, can you also make a PR in the upstream?

Opened triton-lang#5490.

The 6 test failures were due to accidentally keeping the device parameter, which only works in numpy 2.0.0 or later (the version I was testing on). This has been fixed, and the PR now only includes the CPU-backend-specific changes.
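For context, the device keyword on numpy's array-creation functions only exists in numpy >= 2.0 (added as part of array API standard support), so a call like the one sketched below raises a TypeError on numpy 1.x:

```python
import numpy as np

# Accepted on numpy >= 2.0 (only "cpu" is valid); TypeError on numpy 1.x.
x = np.asarray([1.0, 2.0, 3.0], dtype=np.float32, device="cpu")
```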

@stephen-huan (Author)

Argh, there's now a synchronization issue with upstream. Hopefully it's OK to make these modifications to get it to work.

Sorry about the back-and-forth; third time's the charm. I haven't been running the tests due to some (unrelated) segfaults.

============================================================================ test session starts =============================================================================
platform linux -- Python 3.12.8, pytest-8.3.4, pluggy-1.5.0
...
collected 25631 items / 25624 deselected / 7 selected

python/test/unit/runtime/test_autotuner.py::test_kwargs[False] PASSED                                                                                                  [ 14%]
python/test/unit/runtime/test_autotuner.py::test_kwargs[True] SKIPPED (Use cuda graph without cuda looks strange)                                                      [ 28%]
python/test/unit/runtime/test_autotuner.py::test_restore[False] PASSED                                                                                                 [ 42%]
python/test/unit/runtime/test_autotuner.py::test_restore[True] PASSED                                                                                                  [ 57%]
python/test/unit/runtime/test_autotuner.py::test_hooks PASSED                                                                                                          [ 71%]
python/test/unit/runtime/test_autotuner.py::test_prune_configs[False] PASSED                                                                                           [ 85%]
python/test/unit/runtime/test_autotuner.py::test_prune_configs[True] PASSED                                                                                            [100%]

...
======================================================== 6 passed, 1 skipped, 25624 deselected, 5 warnings in 11.74s =========================================================

Successfully merging this pull request may close issue #204: Torch dependency for importing triton, kernel execution and autotuning.