diff --git a/vllm/platforms/cuda.py b/vllm/platforms/cuda.py index 30bbf5107475d..9c5212ace1346 100644 --- a/vllm/platforms/cuda.py +++ b/vllm/platforms/cuda.py @@ -7,6 +7,7 @@ from typing import Callable, List, Tuple, TypeVar import pynvml +import torch from typing_extensions import ParamSpec from vllm.logger import init_logger @@ -26,6 +27,10 @@ " and cause errors. See https://pypi.org/project/pynvml " "for more information.") +# pytorch 2.5 uses cudnn sdpa by default, which will cause crash on some models +# see https://github.com/huggingface/diffusers/issues/9704 for details +torch.backends.cuda.enable_cudnn_sdp(False) + # NVML utils # Note that NVML is not affected by `CUDA_VISIBLE_DEVICES`, # all the related functions work on real physical device ids.