bugfix: workspace dir when no GPU is available (#579)
When GPU is not available (e.g., `CUDA_VISIBLE_DEVICES=-1`),
`_get_cuda_arch_flags()` raises IndexError.

```
/opt/venv/lib/python3.10/site-packages/flashinfer/__init__.py:17: in <module>
    from .activation import (
/opt/venv/lib/python3.10/site-packages/flashinfer/activation.py:21: in <module>
    from .jit import (
/opt/venv/lib/python3.10/site-packages/flashinfer/jit/__init__.py:27: in <module>
    from .activation import (
/opt/venv/lib/python3.10/site-packages/flashinfer/jit/activation.py:19: in <module>
    from .env import FLASHINFER_GEN_SRC_DIR
/opt/venv/lib/python3.10/site-packages/flashinfer/jit/env.py:31: in <module>
    FLASHINFER_WORKSPACE_DIR = _get_workspace_dir_name()
/opt/venv/lib/python3.10/site-packages/flashinfer/jit/env.py:24: in _get_workspace_dir_name
    flags = _get_cuda_arch_flags()
/opt/venv/lib/python3.10/site-packages/torch/utils/cpp_extension.py:1984: in _get_cuda_arch_flags
    arch_list[-1] += '+PTX'
E   IndexError: list index out of range
```

Although FlashInfer is not useful in this case, we still don't want to
crash the user's program when importing flashinfer. This PR fixes the issue.

Another change is to hide the warning about `TORCH_CUDA_ARCH_LIST` not
being set when importing flashinfer, which can be annoying when using AOT wheels.
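The suppression uses the standard-library `warnings` machinery. A minimal, self-contained sketch of the pattern (the `noisy()` function is a stand-in for torch's warning, and the filter here matches by message regex only, without the `module` restriction used in the actual patch):

```python
import warnings

def noisy():
    # Stand-in for torch's "TORCH_CUDA_ARCH_LIST not set" warning.
    warnings.warn("TORCH_CUDA_ARCH_LIST is not set; using defaults.", UserWarning)

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")  # record every warning by default
    # Drop warnings whose message matches the regex; the filter is
    # automatically restored when the catch_warnings() block exits.
    warnings.filterwarnings("ignore", r".*TORCH_CUDA_ARCH_LIST.*")
    noisy()

print(len(caught))  # 0: the warning was filtered out
```

Because `catch_warnings()` saves and restores the global filter state, the suppression is scoped to the import-time call and does not hide the warning for the rest of the user's program.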
abcdabcd987 authored Nov 2, 2024
1 parent fc0f6d4 commit c83cd6c
Showing 1 changed file with 11 additions and 2 deletions.
13 changes: 11 additions & 2 deletions python/flashinfer/jit/env.py

```diff
@@ -16,13 +16,22 @@

 import pathlib
 import re
+import warnings

 from torch.utils.cpp_extension import _get_cuda_arch_flags


 def _get_workspace_dir_name() -> pathlib.Path:
-    flags = _get_cuda_arch_flags()
-    arch = "_".join(sorted(set(re.findall(r"compute_(\d+)", "".join(flags)))))
+    try:
+        with warnings.catch_warnings():
+            # Ignore the warning for TORCH_CUDA_ARCH_LIST not set
+            warnings.filterwarnings(
+                "ignore", r".*TORCH_CUDA_ARCH_LIST.*", module="torch"
+            )
+            flags = _get_cuda_arch_flags()
+            arch = "_".join(sorted(set(re.findall(r"compute_(\d+)", "".join(flags)))))
+    except Exception:
+        arch = "noarch"
     # e.g.: $HOME/.cache/flashinfer/75_80_89_90/
     return pathlib.Path.home() / ".cache" / "flashinfer" / arch
```
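The patched logic can be mirrored in a standalone sketch. Here `workspace_dir` and its `get_flags` parameter are illustrative names for this example, not flashinfer's API; the flag-parsing and `"noarch"` fallback follow the diff above:

```python
import pathlib
import re

def workspace_dir(get_flags) -> pathlib.Path:
    try:
        # e.g. get_flags() -> ["-gencode=arch=compute_80,code=sm_80", ...]
        flags = get_flags()
        # Extract unique compute capabilities and join them, sorted.
        arch = "_".join(sorted(set(re.findall(r"compute_(\d+)", "".join(flags)))))
    except Exception:
        # No usable GPU (or any other failure): still return a valid dir.
        arch = "noarch"
    return pathlib.Path.home() / ".cache" / "flashinfer" / arch

print(workspace_dir(lambda: ["-gencode=arch=compute_75", "-gencode=arch=compute_80"]).name)
# 75_80

def boom():
    raise IndexError("list index out of range")  # mimics _get_cuda_arch_flags on no GPU

print(workspace_dir(boom).name)
# noarch
```

The broad `except Exception` is deliberate: any failure in arch detection degrades to a shared `noarch` cache directory instead of aborting the import.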
