Skip to content

Commit

Permalink
hide the bitblas import
Browse files Browse the repository at this point in the history
  • Loading branch information
LeiWang1999 committed Aug 20, 2024
1 parent 8e1a7e8 commit 7fbbccf
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 21 deletions.
7 changes: 3 additions & 4 deletions benchmarks/kernels/benchmark_bitblas.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,8 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

from vllm.model_executor.layers.quantization.bitblas import (BITBLAS_TARGET,
Matmul,
MatmulConfig)
from bitblas import Matmul, MatmulConfig, auto_detect_nvidia_target

from vllm.utils import FlexibleArgumentParser

parser = FlexibleArgumentParser(
Expand All @@ -13,7 +12,7 @@
parser.add_argument(
"--target",
type=str,
default=BITBLAS_TARGET,
default=auto_detect_nvidia_target(),
help="Specify the target device for benchmarking.",
)
parser.add_argument("--group_size",
Expand Down
22 changes: 5 additions & 17 deletions vllm/model_executor/layers/quantization/bitblas.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,20 +11,6 @@

logger = init_logger(__name__)

try:
import bitblas
from bitblas import Matmul, MatmulConfig
from bitblas.utils import auto_detect_nvidia_target
except ImportError as e:
bitblas_import_exception = e
raise ValueError(
"Trying to use the bitblas backend, but could not import dependencies"
f"with the following error: {bitblas_import_exception}"
) from bitblas_import_exception

BITBLAS_TARGET = auto_detect_nvidia_target()
BITBLAS_DATABASE_PATH = bitblas.cache.get_database_path()

BITBLAS_SUPPORTED_NUM_BITS = [1, 2, 4, 8]
BITBLAS_SUPPORTED_SYM = [False, True]

Expand Down Expand Up @@ -468,7 +454,7 @@ def _configure_bitblas_matmul(
bits,
out_dtype="float16",
):

from bitblas import MatmulConfig
bitblas_dtype = self.BITBLAS_DTYPES[params_dtype]

if self.quant_config.quant_method == "gptq":
Expand Down Expand Up @@ -509,8 +495,10 @@ def _configure_bitblas_matmul(
matmul_config, enable_tuning)

def _get_or_create_bitblas_operator(self, config, enable_tuning):
from bitblas.cache import global_operator_cache

from bitblas import Matmul, auto_detect_nvidia_target
from bitblas.cache import get_database_path, global_operator_cache
BITBLAS_DATABASE_PATH = get_database_path()
BITBLAS_TARGET = auto_detect_nvidia_target()
if global_operator_cache.size() == 0:
global_operator_cache.load_from_database(BITBLAS_DATABASE_PATH,
BITBLAS_TARGET)
Expand Down

0 comments on commit 7fbbccf

Please sign in to comment.