-
Notifications
You must be signed in to change notification settings - Fork 16
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Allow using local omp with Apple clang #181
Conversation
Tested on my M1 Mac as, ``` OMP_NUM_THREADS=8 \ TRITON_LOCAL_LIBOMP_PATH="<path..to>/site-packages/torch/" \ CC=$(which clang) \ TRITON_CPU_BACKEND=1 \ $(which python3) \ python/tutorials/02-fused-softmax-cpu.py ```
f559fdc
to
69a5331
Compare
else: | ||
cc_cmd += ["-fopenmp"] | ||
if libomp_path: | ||
print("Info: Ignoring TRITON_LOCAL_LIBOMP_PATH for non-Apple clang compiler") |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nit: print(..., file=sys.stderr)
as a warning?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Ah good point. I missed this. I will sneak this in with my next PR.
cc_cmd += [f"-L{libomp_path}/lib"] | ||
cc_cmd += ["-lomp"] | ||
else: | ||
print("Warning: TRITON_LOCAL_LIBOMP_PATH is not set for Apple clang. OpenMP is disabled.") |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
What happens in this case? Can triton still compile the kernel and run it, but with a single core?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes because we don't add -fopenmp
on cc_cmd
Tested on my M1 Mac as, ``` OMP_NUM_THREADS=8 \ TRITON_LOCAL_LIBOMP_PATH="<path..to>/site-packages/torch/" \ CC=$(which clang) \ TRITON_CPU_BACKEND=1 \ $(which python3) \ python/tutorials/02-fused-softmax-cpu.py ```
Tested on my M1 Mac as, ``` OMP_NUM_THREADS=8 \ TRITON_LOCAL_LIBOMP_PATH="<path..to>/site-packages/torch/" \ CC=$(which clang) \ TRITON_CPU_BACKEND=1 \ $(which python3) \ python/tutorials/02-fused-softmax-cpu.py ```
I know it's meant to be a hack, but a nicer long-term solution would be to handle this at install time Also I found https://mac.r-project.org/openmp/, we could possibly download libomp from there |
|
Tested as,