Support for compute capability <7.0 #963
Hi @andersRuge, I believe there's no technical reason. We just thought that Pascal GPUs were not very popular these days, and we didn't have Pascal GPUs to test vLLM on. vLLM may work without any modification: just try installing vLLM from source.

git clone https://github.com/vllm-project/vllm.git
cd vllm
pip install -e .
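Before patching anything, it may help to confirm what compute capability your cards actually report. A quick check using PyTorch (assuming torch is already installed in the build environment):

import torch

# Print the compute capability of every visible CUDA device,
# e.g. (6, 0) for a P100 or (6, 1) for a P40 / GTX 1060.
for i in range(torch.cuda.device_count()):
    major, minor = torch.cuda.get_device_capability(i)
    print(f"GPU {i}: {torch.cuda.get_device_name(i)} -> compute capability {major}.{minor}")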
Hi there, I can confirm that it's working with the Pascal architecture (Quadro P2000, CUDA version 12.3) when built from source. Tested on a Dell Precision 5530 laptop (with Python 3.9.18):

git clone https://github.com/vllm-project/vllm.git
cd vllm
# The setup doesn't allow compute capability below 7.0 (lines 149, 150 and 151 cause this because they check the version explicitly)
mv setup.py _setup.py
# We use awk to recreate the file without that 'if' block in Python
awk '!(NR == 151 || NR == 150 || NR == 149)' ./_setup.py > ./setup.py
pip install -e .

Below, find the lines removed from setup.py (marked # REMOVE):

# First, check the TORCH_CUDA_ARCH_LIST environment variable.
compute_capabilities = get_torch_arch_list()
if _is_cuda() and not compute_capabilities:
    # If TORCH_CUDA_ARCH_LIST is not defined or empty, target all available
    # GPUs on the current machine.
    device_count = torch.cuda.device_count()
    for i in range(device_count):
        major, minor = torch.cuda.get_device_capability(i)
        if major < 7:  # REMOVE
            raise RuntimeError(  # REMOVE
                "GPUs with compute capability below 7.0 are not supported.")  # REMOVE
        compute_capabilities.add(f"{major}.{minor}")

[Image of vLLM working on the machine mentioned]
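If awk isn't available, the same edit can be made with a few lines of Python. This is only a sketch: it assumes the check still sits at lines 149-151 of setup.py, as in the awk command above, and it edits setup.py in place rather than going through _setup.py:

from pathlib import Path

# Drop lines 149-151 (1-indexed), i.e. the "compute capability below 7.0"
# check, and write the remaining lines back to setup.py.
path = Path("setup.py")
lines = path.read_text().splitlines(keepends=True)
kept = [line for n, line in enumerate(lines, start=1) if n not in (149, 150, 151)]
path.write_text("".join(kept))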
Thanks for this tip @klebster2! That is exactly what I needed. I was able to get vLLM to work with the current version (v0.2.7 / 220a476) in a Docker container. My test rig is Ubuntu 22.04 with CUDA 12.1; I started with a GTX 1060 and then tested on 4x P100s for a larger model. Based on your notes, here is my how-to. The patch to setup.py:
--- _setup.py 2024-01-27 18:44:45.509406538 +0000
+++ setup.py 2024-01-28 00:02:23.581639719 +0000
@@ -18,7 +18,7 @@
 MAIN_CUDA_VERSION = "12.1"
 
 # Supported NVIDIA GPU architectures.
-NVIDIA_SUPPORTED_ARCHS = {"7.0", "7.5", "8.0", "8.6", "8.9", "9.0"}
+NVIDIA_SUPPORTED_ARCHS = {"6.0", "6.1", "6.2", "7.0", "7.5", "8.0", "8.6", "8.9", "9.0"}
 ROCM_SUPPORTED_ARCHS = {"gfx90a", "gfx908", "gfx906", "gfx1030", "gfx1100"}
 # SUPPORTED_ARCHS = NVIDIA_SUPPORTED_ARCHS.union(ROCM_SUPPORTED_ARCHS)
@@ -184,9 +184,9 @@
     device_count = torch.cuda.device_count()
     for i in range(device_count):
         major, minor = torch.cuda.get_device_capability(i)
-        if major < 7:
+        if major < 6:
             raise RuntimeError(
-                "GPUs with compute capability below 7.0 are not supported.")
+                "GPUs with compute capability below 6.0 are not supported.")
         compute_capabilities.add(f"{major}.{minor}")
 
 ext_modules = []
Dockerfile:

FROM nvidia/cuda:12.1.0-devel-ubuntu22.04
RUN apt-get update -y \
&& apt-get install -y python3-pip
WORKDIR /app
COPY . .
RUN python3 -m pip install -e .
EXPOSE 8001
COPY entrypoint.sh /usr/local/bin/
CMD [ "entrypoint.sh" ]
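The entrypoint.sh referenced above isn't included in this thread (the TinyLLM link below has the real one). Purely as an illustration of how the environment variables from the docker run command below map onto vLLM's OpenAI-compatible server, a hypothetical Python rendering of such an entrypoint might look like this; the exact flags and defaults here are assumptions:

import os
import shlex
import subprocess

# Translate the container's environment variables into a vLLM server invocation.
model = os.environ.get("MODEL", "mistralai/Mistral-7B-Instruct-v0.1")
port = os.environ.get("PORT", "8001")
num_gpu = os.environ.get("NUM_GPU", "1")
extra_args = shlex.split(os.environ.get("EXTRA_ARGS", ""))

cmd = [
    "python3", "-m", "vllm.entrypoints.openai.api_server",
    "--model", model,
    "--port", port,
    "--tensor-parallel-size", num_gpu,
    *extra_args,
]
subprocess.run(cmd, check=True)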
Run the container:

nvidia-docker run -d -p 8001:8001 --gpus=all --shm-size=10.24gb \
-e MODEL=mistralai/Mistral-7B-Instruct-v0.1 \
-e PORT=8001 \
-e HF_HOME=/app/models \
-e NUM_GPU=4 \
-e EXTRA_ARGS="--dtype float --max-model-len 20000" \
-v /path/to/models:/app/models \
--name vllm \
vllm

Additional details here: https://github.com/jasonacox/TinyLLM/tree/main/vllm#running-vllm-on-pascal

Sample throughput from the engine log:

INFO 01-27 23:52:57 llm_engine.py:871] Avg prompt throughput: 0.0 tokens/s, Avg generation throughput: 41.7 tokens/s, Running: 1 reqs, Swapped: 0 reqs, Pending: 0 reqs, GPU KV cache usage: 0.4%, CPU KV cache usage: 0.0%
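Once the container is up, a quick smoke test can confirm the Pascal build is serving requests. This sketch assumes the entrypoint launches vLLM's OpenAI-compatible API server on the mapped port 8001 with the model from the run command above:

import json
import urllib.request

# Send one completion request to the OpenAI-compatible endpoint.
payload = {
    "model": "mistralai/Mistral-7B-Instruct-v0.1",
    "prompt": "Say hello from a Pascal GPU.",
    "max_tokens": 32,
}
req = urllib.request.Request(
    "http://localhost:8001/v1/completions",
    data=json.dumps(payload).encode(),
    headers={"Content-Type": "application/json"},
)
with urllib.request.urlopen(req) as resp:
    print(json.load(resp)["choices"][0]["text"])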
Great to hear it helped!
What's the max tokens/s you can get on 4x P100s?
Running a simple benchmark.py I get: Min TPS: 27.6, Max TPS: 40.1, Avg TPS: 33.9
Considering vLLM is a batched inference server, wouldn't you be able to get more tokens/s by running multiple copies of the benchmark script? Unless it's already maxing out the P100s? Honestly, 34 TPS is kinda low; I was hoping it would be a bit higher so it could be viable as a cheap inference GPU...
You are right, I'm only running one thread with this test. I ran again with 3 benchmarks running simultaneously and all 3 stayed above 30 TPS (so that could be considered 90 TPS?), but I don't know where the ceiling would be. If anyone has a good test script that pushes batching, I would love to try it. UPDATE: I ran a multi-threaded benchmark with 10 threads and got 208 tokens/s. Each concurrent thread was seeing about 21 TPS.
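To push batching harder than a single benchmark thread can, something like the following sketch fires concurrent requests at the same endpoint as the smoke test above and reports aggregate throughput. It is not the benchmark.py mentioned in the thread, and it assumes the server reports token counts in the OpenAI-style usage field:

import concurrent.futures
import json
import time
import urllib.request

URL = "http://localhost:8001/v1/completions"  # assumed endpoint from the setup above
MODEL = "mistralai/Mistral-7B-Instruct-v0.1"

def one_request(prompt: str) -> int:
    """Send one completion request and return the number of generated tokens."""
    payload = {"model": MODEL, "prompt": prompt, "max_tokens": 128}
    req = urllib.request.Request(
        URL, data=json.dumps(payload).encode(),
        headers={"Content-Type": "application/json"},
    )
    with urllib.request.urlopen(req) as resp:
        return json.load(resp)["usage"]["completion_tokens"]

start = time.time()
with concurrent.futures.ThreadPoolExecutor(max_workers=10) as pool:
    tokens = sum(pool.map(one_request, ["Tell me a short story."] * 50))
print(f"Aggregate throughput: {tokens / (time.time() - start):.1f} tokens/s")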
Those are decent speeds out of old cards, actually. Did you have time to test larger models on them too?
I would be happy to, any suggestions? On Mistral 7B Instruct, I'm running the full 32k context, using float16 (bfloat16 is not available on Pascal?) and it is filling up most of the VRAM on all four P100s (each with 16 GB). Now for the other bit... there are 7 GPUs in this system, but vLLM will only split the model across 4 (to get an even split of the 32 layers, I assume). Of course, I'm not wasting those GPUs: I'm using the other 3 for text2vec transformers and smaller models. But it would be fun to test a model on all 7. 😀
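For anyone preferring vLLM's offline Python API over the server, the tensor-parallel degree is set explicitly there. A minimal sketch with the same model, using float16 since Pascal has no native bfloat16 (prompt and token limit are just placeholders):

from vllm import LLM, SamplingParams

# Shard the model across 4 GPUs; Pascal cards need float16 rather than bfloat16.
llm = LLM(
    model="mistralai/Mistral-7B-Instruct-v0.1",
    tensor_parallel_size=4,
    dtype="float16",
)
outputs = llm.generate(["Hello from four P100s!"], SamplingParams(max_tokens=32))
print(outputs[0].outputs[0].text)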
Would anyone here be willing to review this PR? #2635
What is the minimum VRAM required for running 7B models on the P100? If 2 cards are enough, I might try to get 2x P100 to experiment with and try that PR. Also, if you could test larger 34B models on more than 4 cards, that would be awesome too, since I can't even run those with vLLM on 2x 3090s.
I would only go with P100s if they are cheap. They only have 16 GB of VRAM and the Pascal architecture is at the bottom edge of CUDA support. The 7B model with 32k context (
So I have experimented with vLLM some more, and I can run 70B AWQ 4-bit models on my 2x 3090s with up to --max-model-len 16384 and --gpu-memory-utilization 0.98. Have you tried AWQ 4-bit models? I bet you can fit much more on the P100s, since AWQ 4-bit models use so much less VRAM than the full FP16 models do. In terms of performance I am getting this:
It seems to me like the limitation becomes the single-core performance of my CPU when running 7B on my 3090s, since it's not much faster than on my 3060 machine, which has far fewer cores but a much higher clock speed.
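A sketch of loading an AWQ checkpoint through vLLM's Python API, for reference; the model name is just an example AWQ repo, and note that vLLM's AWQ kernels may have their own minimum compute-capability requirement, so whether this runs on Pascal cards is a separate question:

from vllm import LLM, SamplingParams

# 4-bit AWQ weights use a fraction of the VRAM that FP16 weights need.
llm = LLM(
    model="TheBloke/Mistral-7B-Instruct-v0.1-AWQ",  # example AWQ checkpoint (assumption)
    quantization="awq",
    dtype="float16",
)
print(llm.generate(["Hi"], SamplingParams(max_tokens=16))[0].outputs[0].text)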
Thanks @Nero10578! I'll give it a try.
Just to add, I've tested vLLM with a P100 and it works very well. If you want to use only a single card, then you can limit the context size, or better still, use a quantized model, and then it will run very fast as a bonus.
I believe the latest release added optional support without needing to patch. Did you need to do anything special to get it to work?
Oh awesome, I might get some P100s then. I was also asking the Aphrodite Engine devs whether the P100 would work there too, since Aphrodite is an awesome fork of vLLM that supports more quantizations.
Yes, I had to patch the source code. I can send a pull request.
Please merge Pascal support. Many of us are running P40 / P100 rigs (some even built recently for the purpose) because they're a very good VRAM/$ deal given eBay prices and the limited number of PCIe slots available on consumer rigs.
Thanks! Or post it here?
I posted the pull request: #4290
It looks like there's no way to make these changes anymore, as setup.py doesn't seem to contain these checks - is there a more up-to-date way to manually make vLLM compatible with the Pascal architecture (P40s) if this method doesn't work and support isn't going to be added to the main branch?
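One thing worth trying on newer versions, since the old setup.py quoted earlier in this thread consults TORCH_CUDA_ARCH_LIST before probing the local GPUs, is to set that variable to include the Pascal architectures when building from source. This is only a sketch, and whether it is sufficient on current versions isn't confirmed here; the build system may keep its own list of supported architectures that would need the same kind of patch as the diff above:

import os
import subprocess

# Build vLLM from a source checkout with Pascal (6.0/6.1) added to the
# target architecture list that the build reads from the environment.
env = dict(os.environ)
env["TORCH_CUDA_ARCH_LIST"] = "6.0 6.1 7.0 7.5 8.0 8.6 8.9 9.0"
subprocess.run(["pip", "install", "-e", "."], cwd="vllm", env=env, check=True)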
Hi,
How tightly coupled is the requirement for compute capability 7.0 or higher? Is it possible to disable some features and run on, e.g., 6.0, like a P100?
Maybe this is totally unfeasible, but I am limited in my GPU options.