Skip to content

Commit

Permalink
import fix: defer bitblas imports into _get_or_create_bitblas_operator so vllm no longer requires bitblas at module import time
Browse files Browse the repository at this point in the history
  • Loading branch information
LeiWang1999 committed Aug 20, 2024
1 parent 7fbbccf commit c487e69
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 33 deletions.
21 changes: 5 additions & 16 deletions vllm/model_executor/layers/quantization/bitnet_bitblas.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
from enum import Enum
from typing import Any, Dict, List, Optional

import bitblas.cache
import torch
from torch.nn.parameter import Parameter

Expand All @@ -16,20 +15,6 @@

logger = init_logger(__name__)

try:
import bitblas
from bitblas.utils import auto_detect_nvidia_target
except ImportError as e:
bitblas_import_exception = e
error_message = (
"Trying to use the bitblas backend, but could not import dependencies "
f"with the following error: {bitblas_import_exception}")
raise ValueError(error_message) from bitblas_import_exception

bitblas.set_log_level("Debug")
BITBLAS_TARGET = auto_detect_nvidia_target()
BITBLAS_DATABASE_PATH = bitblas.cache.get_database_path()

BITNET_BITBLAS_SUPPORTED_NUM_BITS = [1, 2, 4, 8]


Expand Down Expand Up @@ -285,7 +270,11 @@ def _configure_bitblas_matmul(

def _get_or_create_bitblas_operator(self, config, enable_tuning):
from bitblas import Matmul
from bitblas.cache import global_operator_cache
from bitblas.cache import get_database_path, global_operator_cache
from bitblas.utils import auto_detect_nvidia_target

BITBLAS_DATABASE_PATH = get_database_path()
BITBLAS_TARGET = auto_detect_nvidia_target()

if global_operator_cache.size() == 0:
global_operator_cache.load_from_database(BITBLAS_DATABASE_PATH,
Expand Down
21 changes: 4 additions & 17 deletions vllm/model_executor/layers/quantization/gptq_bitblas.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
from enum import Enum
from typing import Any, Dict, List, Optional

import bitblas.cache
import torch
from torch.nn.parameter import Parameter

Expand All @@ -14,20 +13,6 @@

logger = init_logger(__name__)

try:
import bitblas
from bitblas.utils import auto_detect_nvidia_target
except ImportError as e:
bitblas_import_exception = e
raise ValueError(
"Trying to use the bitblas backend, but could not import dependencies"
f"with the following error: {bitblas_import_exception}"
) from bitblas_import_exception

bitblas.set_log_level("Debug")
BITBLAS_TARGET = auto_detect_nvidia_target()
BITBLAS_DATABASE_PATH = bitblas.cache.get_database_path()

GPTQ_BITBLAS_SUPPORTED_NUM_BITS = [1, 2, 4, 8]
GPTQ_BITBLAS_SUPPORTED_SYM = [False, True]

Expand Down Expand Up @@ -415,8 +400,10 @@ def _configure_bitblas_matmul(
matmul_config, enable_tuning)

def _get_or_create_bitblas_operator(self, config, enable_tuning):
from bitblas import Matmul
from bitblas.cache import global_operator_cache
from bitblas import Matmul, auto_detect_nvidia_target
from bitblas.cache import get_database_path, global_operator_cache
BITBLAS_DATABASE_PATH = get_database_path()
BITBLAS_TARGET = auto_detect_nvidia_target()

if global_operator_cache.size() == 0:
global_operator_cache.load_from_database(BITBLAS_DATABASE_PATH,
Expand Down

0 comments on commit c487e69

Please sign in to comment.