[Bugfix] Fix gptq failure on T4s #7264

Merged
vllm/model_executor/layers/quantization/awq_marlin.py (1 addition, 2 deletions)

@@ -126,8 +126,7 @@ def is_awq_marlin_compatible(cls, quant_config: Dict[str, Any]):

         return check_marlin_supported(quant_type=cls.TYPE_MAP[num_bits],
                                       group_size=group_size,
-                                      has_zp=has_zp,
-                                      min_capability=cls.get_min_capability())
+                                      has_zp=has_zp)


 class AWQMarlinLinearMethod(LinearMethodBase):
vllm/model_executor/layers/quantization/gptq_marlin.py (1 addition, 2 deletions)

@@ -136,8 +136,7 @@ def is_gptq_marlin_compatible(cls, quant_config: Dict[str, Any]):
             return False

         return check_marlin_supported(quant_type=cls.TYPE_MAP[(num_bits, sym)],
-                                      group_size=group_size,
-                                      min_capability=cls.get_min_capability())
+                                      group_size=group_size)


 class GPTQMarlinLinearMethod(LinearMethodBase):
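For context on the two callers above: a T4 reports compute capability 7.5, below Marlin's minimum of 8.0. Both compatibility checks passed cls.get_min_capability() (the kernel's required minimum, 80) as the min_capability argument, which the utility then treated as the current device's capability, so its "< 80" rejection could never fire and Marlin was selected even on T4s. Dropping the argument lets the utility query the real device. Below is a minimal standalone sketch of that failure mode; marlin_ok and MARLIN_MIN_CAPABILITY are hypothetical stand-ins for the real helpers, with the 80/75 capability values taken from the diff.

from typing import Optional

MARLIN_MIN_CAPABILITY = 80  # sm_80 (Ampere); matches the "< 80" check in the diff


def marlin_ok(device_capability: Optional[int] = None) -> bool:
    """Sketch of the capability gate in query_marlin_supported_quant_types."""
    if device_capability is None:
        device_capability = 75  # pretend the current device is a T4 (sm_75)
    return device_capability >= MARLIN_MIN_CAPABILITY


# Buggy call pattern: the caller passed the *required minimum* (80) where the
# *device* capability belonged, so the gate always passed, even on a T4.
assert marlin_ok(MARLIN_MIN_CAPABILITY) is True   # wrong answer on a T4

# Fixed call pattern: omit the argument and let the utility probe the device.
assert marlin_ok() is False                       # T4 correctly rejected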
vllm/model_executor/layers/quantization/utils/marlin_utils.py (12 additions, 11 deletions)

@@ -26,12 +26,13 @@
 # without runtime zero-point. We support common cases, i.e. AWQ and GPTQ.
 # TODO: we may want to move this into the C++ so its closer to the actual impl
 def query_marlin_supported_quant_types(has_zp: bool,
-                                       min_capability: Optional[int] = None):
-    if min_capability is None:
+                                       device_capability: Optional[int] = None
+                                       ):
+    if device_capability is None:
         major, minor = current_platform.get_device_capability()
-        min_capability = major * 10 + minor
+        device_capability = major * 10 + minor

-    if min_capability < 80:
+    if device_capability < 80:
         return []

     if has_zp:

@@ -48,20 +49,20 @@ def _check_marlin_supported(
         quant_type: ScalarType,
         group_size: Optional[int],
         has_zp: bool,
-        min_capability: Optional[int] = None) -> Tuple[bool, Optional[str]]:
+        device_capability: Optional[int] = None) -> Tuple[bool, Optional[str]]:

-    if min_capability is None:
+    if device_capability is None:
         major, minor = current_platform.get_device_capability()
-        min_capability = major * 10 + minor
+        device_capability = major * 10 + minor

     supported_types = query_marlin_supported_quant_types(
-        has_zp, min_capability)
+        has_zp, device_capability)

     if quant_type not in supported_types:
         return (False, f"Marlin does not support weight_bits = {quant_type}. "
                 f"Only types = {supported_types} "
                 f"are supported (for group_size = {group_size}, "
-                f"min_capability = {min_capability}, zp = {has_zp}).")
+                f"device_capability = {device_capability}, zp = {has_zp}).")
     if (group_size is None or group_size not in MARLIN_SUPPORTED_GROUP_SIZES):
         return (False, f"Marlin does not support group_size = {group_size}. "
                 f"Only group_sizes = {MARLIN_SUPPORTED_GROUP_SIZES} "

@@ -73,9 +74,9 @@ def _check_marlin_supported(
 def check_marlin_supported(quant_type: ScalarType,
                            group_size: int,
                            has_zp: bool = False,
-                           min_capability: Optional[int] = None) -> bool:
+                           device_capability: Optional[int] = None) -> bool:
     cond, _ = _check_marlin_supported(quant_type, group_size, has_zp,
-                                      min_capability)
+                                      device_capability)
     return cond
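The rename makes the parameter's meaning explicit: it is the capability of the device being queried, not a required minimum. Callers that want to probe a specific capability can still pass device_capability directly. A hedged usage sketch follows; it assumes a vLLM checkout at this revision (the import path mirrors the file path in the diff), and the 75/80 values and empty-list behavior follow from the "< 80" check shown above.

# Usage sketch against the renamed helper; assumes vLLM at this revision.
from vllm.model_executor.layers.quantization.utils.marlin_utils import (
    query_marlin_supported_quant_types)

# On pre-Ampere hardware (e.g. a T4, device_capability=75) Marlin supports
# no quant types at all, so compatibility checks now fail fast.
assert query_marlin_supported_quant_types(has_zp=False,
                                          device_capability=75) == []

# On Ampere or newer (>= 80), the GPTQ-style (no zero-point) types come back.
assert query_marlin_supported_quant_types(has_zp=False,
                                          device_capability=80) != []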