Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[NVIDIA] Add support for tensor conversion from fp16 to fp32 using ExtFOp #3874

Closed
wants to merge 1 commit into from

Conversation

sasha0552
Copy link

@sasha0552 sasha0552 commented May 10, 2024

This PR fixes
Unsupported conversion from f16 to f16
LLVM ERROR: Unsupported rounding mode for conversion.
on Pascal GPUs.

@sasha0552 sasha0552 requested a review from ptillet as a code owner May 10, 2024 07:42
@sasha0552 sasha0552 changed the title Fix F16 -> F32 upcasting to support Pascal GPUs [NVIDIA] Fix F16 -> F32 upcasting to support Pascal GPUs May 10, 2024
@Jokeren
Copy link
Contributor

Jokeren commented May 10, 2024

I'm not sure if we are interested in supporting Pascal

NVIDIA GPUs (Compute Capability 7.0+)

@ptillet
Copy link
Collaborator

ptillet commented May 10, 2024

In principle I am not opposed to fixing up pascal issues if the added complexity is minimal, but here could you elaborate a little bit more on why the ampere code path is failing?

@sasha0552
Copy link
Author

sasha0552 commented May 10, 2024

First, Pascal GPUs have very poor performance in fp16, and there is no bf16. (I think maybe it's important).

The crash happens in tl.dot, which generates tt.fp_to_fp (I think it's because of out_dtype=float32, because if you use fp32 tensors as the input, everything works fine.). There is no support for fp16 -> fp32 upcasting in that (fp_to_fp) operation.

https://github.com/openai/triton/blob/a263360050e1887a2cda0c2cac811ddd3ccaab1e/python/triton/language/core.py#L1505

As for the fpext that @ThomasRaoux mentioned, this can be related:
https://github.com/openai/triton/blob/161f7a48d33275917372c3113aa70949c5bee9a8/lib/Dialect/TritonGPU/Transforms/AccelerateMatmul.cpp#L340-L345
I'll try adding the conversion using fpext here instead of what was suggested in this PR.

Code
import torch

import triton
import triton.language as tl

@triton.jit
def test_dot_kernel():
    t1 = tl.zeros([16, 16], dtype=tl.float16)
    t2 = tl.zeros([16, 16], dtype=tl.float16)
    d = tl.dot(t1, t2)
    tl.device_print("dot:", d)

grid = lambda meta: (1, )
kernel = test_dot_kernel[grid]()
Generated IR (?)
#blocked = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [2, 16], warpsPerCTA = [4, 1], order = [1, 0]}>
module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, triton_gpu.target = "cuda:61", "triton_gpu.threads-per-warp" = 32 : i32} {
  tt.func public @test_dot_kernel() attributes {noinline = false} {
    %cst = arith.constant dense<0.000000e+00> : tensor<16x16xf32, #blocked> loc(#loc1)
    %cst_0 = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked}>> loc(#loc1)
    %cst_1 = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #blocked}>> loc(#loc1)
    %0 = tt.fp_to_fp %cst_0 : tensor<16x16xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked}>> -> tensor<16x16xf32, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked}>> loc(#loc2)
    %1 = tt.fp_to_fp %cst_1 : tensor<16x16xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #blocked}>> -> tensor<16x16xf32, #triton_gpu.dot_op<{opIdx = 1, parent = #blocked}>> loc(#loc2)
    %2 = tt.dot %0, %1, %cst, inputPrecision = tf32 : tensor<16x16xf32, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked}>> * tensor<16x16xf32, #triton_gpu.dot_op<{opIdx = 1, parent = #blocked}>> -> tensor<16x16xf32, #blocked> loc(#loc2)
    tt.print " dot: " {hex = false} : %2 : tensor<16x16xf32, #blocked> loc(#loc3)
    tt.return loc(#loc4)
  } loc(#loc)
} loc(#loc)
#loc = loc("/tmp/repro-1.py":7:0)
#loc1 = loc(unknown)
#loc2 = loc("/tmp/repro-1.py":10:19)
#loc3 = loc("/tmp/repro-1.py":11:28)
#loc4 = loc("/tmp/repro-1.py":11:4)
Stacktrace
Unsupported conversion from f16 to f16
LLVM ERROR: Unsupported rounding mode for conversion.
*** SIGABRT received at time=1715319299 on cpu 0 ***
PC: @     0x787e96bdb32c  (unknown)  (unknown)
    @     0x787e96b8a770      31888  (unknown)
    @     0x5dd359f26fb0  (unknown)  (unknown)
[2024-05-10 05:34:59,869 E 1941 2481] logging.cc:365: *** SIGABRT received at time=1715319299 on cpu 0 ***
[2024-05-10 05:34:59,870 E 1941 2481] logging.cc:365: PC: @     0x787e96bdb32c  (unknown)  (unknown)
[2024-05-10 05:34:59,871 E 1941 2481] logging.cc:365:     @     0x787e96b8a770      31888  (unknown)
[2024-05-10 05:34:59,872 E 1941 2481] logging.cc:365:     @     0x5dd359f26fb0  (unknown)  (unknown)
Fatal Python error: Aborted

Stack (most recent call first):
  File "/mnt/ml/vllm/venv/lib/python3.11/site-packages/triton/backends/nvidia/compiler.py", line 212 in make_llir
  File "/mnt/ml/vllm/venv/lib/python3.11/site-packages/triton/backends/nvidia/compiler.py", line 302 in <lambda>
  File "/mnt/ml/vllm/venv/lib/python3.11/site-packages/triton/compiler/compiler.py", line 282 in compile
  File "/mnt/ml/vllm/venv/lib/python3.11/site-packages/triton/runtime/jit.py", line 662 in run
  File "/mnt/ml/vllm/venv/lib/python3.11/site-packages/triton/runtime/jit.py", line 345 in <lambda>
  File "/mnt/ml/vllm/vllm/attention/ops/prefix_prefill.py", line 744 in context_attention_fwd
  File "/mnt/ml/vllm/venv/lib/python3.11/site-packages/torch/utils/_contextlib.py", line 115 in decorate_context
  File "/mnt/ml/vllm/vllm/attention/ops/paged_attn.py", line 177 in forward_prefix
  File "/mnt/ml/vllm/vllm/attention/backends/xformers.py", line 236 in forward
  File "/mnt/ml/vllm/vllm/attention/layer.py", line 48 in forward
  File "/mnt/ml/vllm/venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1541 in _call_impl
  File "/mnt/ml/vllm/venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1532 in _wrapped_call_impl
  File "/mnt/ml/vllm/vllm/model_executor/models/llama.py", line 167 in forward
  File "/mnt/ml/vllm/venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1541 in _call_impl
  File "/mnt/ml/vllm/venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1532 in _wrapped_call_impl
  File "/mnt/ml/vllm/vllm/model_executor/models/llama.py", line 233 in forward
  File "/mnt/ml/vllm/venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1541 in _call_impl
  File "/mnt/ml/vllm/venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1532 in _wrapped_call_impl
  File "/mnt/ml/vllm/vllm/model_executor/models/llama.py", line 291 in forward
  File "/mnt/ml/vllm/venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1541 in _call_impl
  File "/mnt/ml/vllm/venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1532 in _wrapped_call_impl
  File "/mnt/ml/vllm/vllm/model_executor/models/llama.py", line 364 in forward
  File "/mnt/ml/vllm/venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1541 in _call_impl
  File "/mnt/ml/vllm/venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1532 in _wrapped_call_impl
  File "/mnt/ml/vllm/vllm/worker/model_runner.py", line 809 in execute_model
  File "/mnt/ml/vllm/venv/lib/python3.11/site-packages/torch/utils/_contextlib.py", line 115 in decorate_context
  File "/mnt/ml/vllm/vllm/worker/worker.py", line 254 in execute_model
  File "/mnt/ml/vllm/venv/lib/python3.11/site-packages/torch/utils/_contextlib.py", line 115 in decorate_context
  File "/mnt/ml/vllm/vllm/worker/worker_base.py", line 137 in execute_method
  File "/usr/lib/python3.11/concurrent/futures/thread.py", line 58 in run
  File "/usr/lib/python3.11/concurrent/futures/thread.py", line 83 in _worker
  File "/usr/lib/python3.11/threading.py", line 982 in run
  File "/usr/lib/python3.11/threading.py", line 1045 in _bootstrap_inner
  File "/usr/lib/python3.11/threading.py", line 1002 in _bootstrap

Extension modules: numpy.core._multiarray_umath, numpy.core._multiarray_tests, numpy.linalg._umath_linalg, numpy.fft._pocketfft_internal, numpy.random._common, numpy.random.bit_generator, numpy.random._bounded_integers, numpy.random._mt19937, numpy.random.mtrand, numpy.random._philox, numpy.random._pcg64, numpy.random._sfc64, numpy.random._generator, torch._C, torch._C._fft, torch._C._linalg, torch._C._nested, torch._C._nn, torch._C._sparse, torch._C._special, charset_normalizer.md, yaml._yaml, psutil._psutil_linux, psutil._psutil_posix, sentencepiece._sentencepiece, msgpack._cmsgpack, google._upb._message, setproctitle, uvloop.loop, ray._raylet, ujson, regex._regex, scipy._lib._ccallback_c, numba.core.typeconv._typeconv, numba._helperlib, numba._dynfunc, numba._dispatcher, numba.core.runtime._nrt_python, numba.np.ufunc._internal, numba.experimental.jitclass._box, markupsafe._speedups, PIL._imaging, scipy.linalg._fblas, scipy.linalg._flapack, scipy.linalg.cython_lapack, scipy.linalg._cythonized_array_utils, scipy.linalg._solve_toeplitz, scipy.linalg._decomp_lu_cython, scipy.linalg._matfuncs_sqrtm_triu, scipy.linalg.cython_blas, scipy.linalg._matfuncs_expm, scipy.linalg._decomp_update, scipy.sparse._sparsetools, _csparsetools, scipy.sparse._csparsetools, scipy.sparse.linalg._dsolve._superlu, scipy.sparse.linalg._eigen.arpack._arpack, scipy.sparse.linalg._propack._spropack, scipy.sparse.linalg._propack._dpropack, scipy.sparse.linalg._propack._cpropack, scipy.sparse.linalg._propack._zpropack, scipy.sparse.csgraph._tools, scipy.sparse.csgraph._shortest_path, scipy.sparse.csgraph._traversal, scipy.sparse.csgraph._min_spanning_tree, scipy.sparse.csgraph._flow, scipy.sparse.csgraph._matching, scipy.sparse.csgraph._reordering, scipy.optimize._minpack2, scipy.optimize._group_columns, scipy._lib.messagestream, scipy.optimize._trlib._trlib, scipy.optimize._lbfgsb, _moduleTNC, scipy.optimize._moduleTNC, scipy.optimize._cobyla, scipy.optimize._slsqp, scipy.optimize._minpack, scipy.optimize._lsq.givens_elimination, scipy.optimize._zeros, scipy.optimize._highs.cython.src._highs_wrapper, scipy.optimize._highs._highs_wrapper, scipy.optimize._highs.cython.src._highs_constants, scipy.optimize._highs._highs_constants, scipy.linalg._interpolative, scipy.optimize._bglu_dense, scipy.optimize._lsap, scipy.spatial._ckdtree, scipy.spatial._qhull, scipy.spatial._voronoi, scipy.spatial._distance_wrap, scipy.spatial._hausdorff, scipy.special._ufuncs_cxx, scipy.special._cdflib, scipy.special._ufuncs, scipy.special._specfun, scipy.special._comb, scipy.special._ellip_harm_2, scipy.spatial.transform._rotation, scipy.optimize._direct, httptools.parser.parser, httptools.parser.url_parser, websockets.speedups, cuda_utils (total: 104)

@sasha0552 sasha0552 force-pushed the pascal-f16tof32-upcast branch from 13dea63 to 5d76f19 Compare May 11, 2024 05:34
@sasha0552 sasha0552 changed the title [NVIDIA] Fix F16 -> F32 upcasting to support Pascal GPUs [NVIDIA] Add support for tensor conversion from fp16 to fp32 using ExtFOp May 11, 2024
@sasha0552 sasha0552 requested a review from ThomasRaoux May 11, 2024 05:50
@ptillet
Copy link
Collaborator

ptillet commented May 13, 2024

Unfortunately, this is precisely the kind of workarounds we meant to avoid when we dropped support for pre-A100 GPUs. On all GPUs supported by Triton, dot should not implicitly upcast its operands to float32. I am sure that there are many, many other rough edges for pascal (especially when it comes to transpositions), and we can't commit to accepting all the PRs that fix them. Thanks for your understanding!

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants