diff --git a/vllm/model_executor/layers/quantization/bitnet_bitblas.py b/vllm/model_executor/layers/quantization/bitnet_bitblas.py
index 04625a536df8..597eac3eef54 100644
--- a/vllm/model_executor/layers/quantization/bitnet_bitblas.py
+++ b/vllm/model_executor/layers/quantization/bitnet_bitblas.py
@@ -2,7 +2,6 @@
 from enum import Enum
 from typing import Any, Dict, List, Optional
 
-import bitblas.cache
 import torch
 from torch.nn.parameter import Parameter
 
@@ -16,20 +15,6 @@
 
 logger = init_logger(__name__)
 
-try:
-    import bitblas
-    from bitblas.utils import auto_detect_nvidia_target
-except ImportError as e:
-    bitblas_import_exception = e
-    error_message = (
-        "Trying to use the bitblas backend, but could not import dependencies "
-        f"with the following error: {bitblas_import_exception}")
-    raise ValueError(error_message) from bitblas_import_exception
-
-bitblas.set_log_level("Debug")
-BITBLAS_TARGET = auto_detect_nvidia_target()
-BITBLAS_DATABASE_PATH = bitblas.cache.get_database_path()
-
 BITNET_BITBLAS_SUPPORTED_NUM_BITS = [1, 2, 4, 8]
 
 
@@ -285,7 +270,11 @@ def _configure_bitblas_matmul(
 
     def _get_or_create_bitblas_operator(self, config, enable_tuning):
         from bitblas import Matmul
-        from bitblas.cache import global_operator_cache
+        from bitblas.cache import get_database_path, global_operator_cache
+        from bitblas.utils import auto_detect_nvidia_target
+
+        BITBLAS_DATABASE_PATH = get_database_path()
+        BITBLAS_TARGET = auto_detect_nvidia_target()
 
         if global_operator_cache.size() == 0:
             global_operator_cache.load_from_database(BITBLAS_DATABASE_PATH,
diff --git a/vllm/model_executor/layers/quantization/gptq_bitblas.py b/vllm/model_executor/layers/quantization/gptq_bitblas.py
index 0dfd9d638944..22c6f2062e92 100644
--- a/vllm/model_executor/layers/quantization/gptq_bitblas.py
+++ b/vllm/model_executor/layers/quantization/gptq_bitblas.py
@@ -2,7 +2,6 @@
 from enum import Enum
 from typing import Any, Dict, List, Optional
 
-import bitblas.cache
 import torch
 from torch.nn.parameter import Parameter
 
@@ -14,20 +13,6 @@
 
 logger = init_logger(__name__)
 
-try:
-    import bitblas
-    from bitblas.utils import auto_detect_nvidia_target
-except ImportError as e:
-    bitblas_import_exception = e
-    raise ValueError(
-        "Trying to use the bitblas backend, but could not import dependencies"
-        f"with the following error: {bitblas_import_exception}"
-    ) from bitblas_import_exception
-
-bitblas.set_log_level("Debug")
-BITBLAS_TARGET = auto_detect_nvidia_target()
-BITBLAS_DATABASE_PATH = bitblas.cache.get_database_path()
-
 GPTQ_BITBLAS_SUPPORTED_NUM_BITS = [1, 2, 4, 8]
 GPTQ_BITBLAS_SUPPORTED_SYM = [False, True]
 
@@ -415,8 +400,10 @@
             matmul_config, enable_tuning)
 
     def _get_or_create_bitblas_operator(self, config, enable_tuning):
-        from bitblas import Matmul
-        from bitblas.cache import global_operator_cache
+        from bitblas import Matmul, auto_detect_nvidia_target
+        from bitblas.cache import get_database_path, global_operator_cache
+        BITBLAS_DATABASE_PATH = get_database_path()
+        BITBLAS_TARGET = auto_detect_nvidia_target()
 
         if global_operator_cache.size() == 0:
             global_operator_cache.load_from_database(BITBLAS_DATABASE_PATH,
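
Both files get the same treatment: the module-level bitblas import (and the hard ImportError raised at import time) is removed, and the target/database lookup is deferred into _get_or_create_bitblas_operator, so bitblas is only required once a BitBLAS operator is actually built. For reference, a minimal sketch of the method body after the patch, assuming the part of the method below the visible hunk (cache lookup and Matmul construction) is unchanged; the second argument to load_from_database is cut off in the hunk and is assumed here to be the detected target:

    def _get_or_create_bitblas_operator(self, config, enable_tuning):
        # Lazy imports: bitblas is only needed when this path is hit.
        from bitblas import Matmul
        from bitblas.cache import get_database_path, global_operator_cache
        from bitblas.utils import auto_detect_nvidia_target

        BITBLAS_DATABASE_PATH = get_database_path()
        BITBLAS_TARGET = auto_detect_nvidia_target()

        # Populate the in-process operator cache from the on-disk database
        # on first use.
        if global_operator_cache.size() == 0:
            global_operator_cache.load_from_database(BITBLAS_DATABASE_PATH,
                                                     BITBLAS_TARGET)  # target arg assumed

        # Cache lookup / Matmul(config) creation and optional tuning
        # continue as in the existing implementation (not shown in the hunk).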