From 7cbe33b5c0bc4fbff167092c2c646cef81823459 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Danie=CC=88l=20de=20Kok?= Date: Mon, 24 Jun 2024 15:11:49 +0200 Subject: [PATCH] Use GPTQ-Marlin for supported GPTQ configurations GPTQ-Marlin is currently the best-performing kernel for GPTQ models. So let's use it by default if the kernels are installed, the GPU supports it, and the kernels support the configuration. For models generated by `text-generation-server quantize`, use `sym=False`. This subcommand has used asymmetric quantization since the beginning, and incorrectly reporting the model to be symmetric would make it use GPTQ-Marlin (which does not support asymmetric quantization). --- .../test_flash_llama_gptq_marlin.json | 84 ----- ...st_flash_llama_gptq_marlin_all_params.json | 84 ----- .../test_flash_llama_gptq_marlin_load.json | 338 ------------------ .../models/test_flash_llama_gptq_marlin.py | 65 ---- .../layers/gptq/__init__.py | 9 + .../text_generation_server/layers/linear.py | 63 ++-- .../text_generation_server/layers/marlin.py | 14 + .../text_generation_server/utils/weights.py | 257 ++++++------- 8 files changed, 162 insertions(+), 752 deletions(-) delete mode 100644 integration-tests/models/__snapshots__/test_flash_llama_gptq_marlin/test_flash_llama_gptq_marlin.json delete mode 100644 integration-tests/models/__snapshots__/test_flash_llama_gptq_marlin/test_flash_llama_gptq_marlin_all_params.json delete mode 100644 integration-tests/models/__snapshots__/test_flash_llama_gptq_marlin/test_flash_llama_gptq_marlin_load.json delete mode 100644 integration-tests/models/test_flash_llama_gptq_marlin.py diff --git a/integration-tests/models/__snapshots__/test_flash_llama_gptq_marlin/test_flash_llama_gptq_marlin.json b/integration-tests/models/__snapshots__/test_flash_llama_gptq_marlin/test_flash_llama_gptq_marlin.json deleted file mode 100644 index 0f99d2597e5..00000000000 --- a/integration-tests/models/__snapshots__/test_flash_llama_gptq_marlin/test_flash_llama_gptq_marlin.json +++ /dev/null @@ -1,84 +0,0 @@ -{ - "details": { - "best_of_sequences": null, - "finish_reason": "length", - "generated_tokens": 10, - "prefill": [ - { - "id": 2323, - "logprob": null, - "text": "Test" - }, - { - "id": 1715, - "logprob": -11.34375, - "text": " request" - } - ], - "seed": null, - "tokens": [ - { - "id": 198, - "logprob": -2.5742188, - "special": false, - "text": "\n" - }, - { - "id": 262, - "logprob": -1.6230469, - "special": false, - "text": " " - }, - { - "id": 3270, - "logprob": -2.046875, - "special": false, - "text": " \"\"\"\n" - }, - { - "id": 262, - "logprob": -0.015281677, - "special": false, - "text": " " - }, - { - "id": 422, - "logprob": -2.1425781, - "special": false, - "text": " if" - }, - { - "id": 1715, - "logprob": -0.9238281, - "special": false, - "text": " request" - }, - { - "id": 13204, - "logprob": -0.076660156, - "special": false, - "text": ".method" - }, - { - "id": 624, - "logprob": -0.021987915, - "special": false, - "text": " ==" - }, - { - "id": 364, - "logprob": -0.39208984, - "special": false, - "text": " '" - }, - { - "id": 3019, - "logprob": -0.10821533, - "special": false, - "text": "POST" - } - ], - "top_tokens": null - }, - "generated_text": "\n \"\"\"\n if request.method == 'POST" -} diff --git a/integration-tests/models/__snapshots__/test_flash_llama_gptq_marlin/test_flash_llama_gptq_marlin_all_params.json b/integration-tests/models/__snapshots__/test_flash_llama_gptq_marlin/test_flash_llama_gptq_marlin_all_params.json deleted file mode 100644 index 4152b5b308b..00000000000 --- 
a/integration-tests/models/__snapshots__/test_flash_llama_gptq_marlin/test_flash_llama_gptq_marlin_all_params.json +++ /dev/null @@ -1,84 +0,0 @@ -{ - "details": { - "best_of_sequences": null, - "finish_reason": "length", - "generated_tokens": 10, - "prefill": [ - { - "id": 2323, - "logprob": null, - "text": "Test" - }, - { - "id": 1715, - "logprob": -11.34375, - "text": " request" - } - ], - "seed": 0, - "tokens": [ - { - "id": 13, - "logprob": -2.2539062, - "special": false, - "text": "." - }, - { - "id": 578, - "logprob": -0.15563965, - "special": false, - "text": " The" - }, - { - "id": 3622, - "logprob": -0.8203125, - "special": false, - "text": " server" - }, - { - "id": 706, - "logprob": 0.0, - "special": false, - "text": " has" - }, - { - "id": 539, - "logprob": 0.0, - "special": false, - "text": " not" - }, - { - "id": 3686, - "logprob": 0.0, - "special": false, - "text": " yet" - }, - { - "id": 3288, - "logprob": 0.0, - "special": false, - "text": " sent" - }, - { - "id": 904, - "logprob": 0.0, - "special": false, - "text": " any" - }, - { - "id": 828, - "logprob": 0.0, - "special": false, - "text": " data" - }, - { - "id": 382, - "logprob": -1.5517578, - "special": false, - "text": ".\n\n" - } - ], - "top_tokens": null - }, - "generated_text": "Test request. The server has not yet sent any data.\n\n" -} diff --git a/integration-tests/models/__snapshots__/test_flash_llama_gptq_marlin/test_flash_llama_gptq_marlin_load.json b/integration-tests/models/__snapshots__/test_flash_llama_gptq_marlin/test_flash_llama_gptq_marlin_load.json deleted file mode 100644 index 75e903033c4..00000000000 --- a/integration-tests/models/__snapshots__/test_flash_llama_gptq_marlin/test_flash_llama_gptq_marlin_load.json +++ /dev/null @@ -1,338 +0,0 @@ -[ - { - "details": { - "best_of_sequences": null, - "finish_reason": "length", - "generated_tokens": 10, - "prefill": [ - { - "id": 2323, - "logprob": null, - "text": "Test" - }, - { - "id": 1715, - "logprob": -11.34375, - "text": " request" - } - ], - "seed": null, - "tokens": [ - { - "id": 198, - "logprob": -2.5742188, - "special": false, - "text": "\n" - }, - { - "id": 262, - "logprob": -1.6220703, - "special": false, - "text": " " - }, - { - "id": 3270, - "logprob": -2.0410156, - "special": false, - "text": " \"\"\"\n" - }, - { - "id": 262, - "logprob": -0.015281677, - "special": false, - "text": " " - }, - { - "id": 422, - "logprob": -2.1445312, - "special": false, - "text": " if" - }, - { - "id": 1715, - "logprob": -0.92333984, - "special": false, - "text": " request" - }, - { - "id": 13204, - "logprob": -0.07672119, - "special": false, - "text": ".method" - }, - { - "id": 624, - "logprob": -0.021987915, - "special": false, - "text": " ==" - }, - { - "id": 364, - "logprob": -0.39208984, - "special": false, - "text": " '" - }, - { - "id": 3019, - "logprob": -0.10638428, - "special": false, - "text": "POST" - } - ], - "top_tokens": null - }, - "generated_text": "\n \"\"\"\n if request.method == 'POST" - }, - { - "details": { - "best_of_sequences": null, - "finish_reason": "length", - "generated_tokens": 10, - "prefill": [ - { - "id": 2323, - "logprob": null, - "text": "Test" - }, - { - "id": 1715, - "logprob": -11.34375, - "text": " request" - } - ], - "seed": null, - "tokens": [ - { - "id": 198, - "logprob": -2.5742188, - "special": false, - "text": "\n" - }, - { - "id": 262, - "logprob": -1.6220703, - "special": false, - "text": " " - }, - { - "id": 3270, - "logprob": -2.0410156, - "special": false, - "text": " \"\"\"\n" - }, - { - "id": 262, - 
"logprob": -0.015281677, - "special": false, - "text": " " - }, - { - "id": 422, - "logprob": -2.1445312, - "special": false, - "text": " if" - }, - { - "id": 1715, - "logprob": -0.92333984, - "special": false, - "text": " request" - }, - { - "id": 13204, - "logprob": -0.07672119, - "special": false, - "text": ".method" - }, - { - "id": 624, - "logprob": -0.021987915, - "special": false, - "text": " ==" - }, - { - "id": 364, - "logprob": -0.39208984, - "special": false, - "text": " '" - }, - { - "id": 3019, - "logprob": -0.10638428, - "special": false, - "text": "POST" - } - ], - "top_tokens": null - }, - "generated_text": "\n \"\"\"\n if request.method == 'POST" - }, - { - "details": { - "best_of_sequences": null, - "finish_reason": "length", - "generated_tokens": 10, - "prefill": [ - { - "id": 2323, - "logprob": null, - "text": "Test" - }, - { - "id": 1715, - "logprob": -11.34375, - "text": " request" - } - ], - "seed": null, - "tokens": [ - { - "id": 198, - "logprob": -2.5742188, - "special": false, - "text": "\n" - }, - { - "id": 262, - "logprob": -1.6220703, - "special": false, - "text": " " - }, - { - "id": 3270, - "logprob": -2.0410156, - "special": false, - "text": " \"\"\"\n" - }, - { - "id": 262, - "logprob": -0.015281677, - "special": false, - "text": " " - }, - { - "id": 422, - "logprob": -2.1445312, - "special": false, - "text": " if" - }, - { - "id": 1715, - "logprob": -0.92333984, - "special": false, - "text": " request" - }, - { - "id": 13204, - "logprob": -0.07672119, - "special": false, - "text": ".method" - }, - { - "id": 624, - "logprob": -0.021987915, - "special": false, - "text": " ==" - }, - { - "id": 364, - "logprob": -0.39208984, - "special": false, - "text": " '" - }, - { - "id": 3019, - "logprob": -0.10638428, - "special": false, - "text": "POST" - } - ], - "top_tokens": null - }, - "generated_text": "\n \"\"\"\n if request.method == 'POST" - }, - { - "details": { - "best_of_sequences": null, - "finish_reason": "length", - "generated_tokens": 10, - "prefill": [ - { - "id": 2323, - "logprob": null, - "text": "Test" - }, - { - "id": 1715, - "logprob": -11.34375, - "text": " request" - } - ], - "seed": null, - "tokens": [ - { - "id": 198, - "logprob": -2.5742188, - "special": false, - "text": "\n" - }, - { - "id": 262, - "logprob": -1.6220703, - "special": false, - "text": " " - }, - { - "id": 3270, - "logprob": -2.0410156, - "special": false, - "text": " \"\"\"\n" - }, - { - "id": 262, - "logprob": -0.015281677, - "special": false, - "text": " " - }, - { - "id": 422, - "logprob": -2.1445312, - "special": false, - "text": " if" - }, - { - "id": 1715, - "logprob": -0.92333984, - "special": false, - "text": " request" - }, - { - "id": 13204, - "logprob": -0.07672119, - "special": false, - "text": ".method" - }, - { - "id": 624, - "logprob": -0.021987915, - "special": false, - "text": " ==" - }, - { - "id": 364, - "logprob": -0.39208984, - "special": false, - "text": " '" - }, - { - "id": 3019, - "logprob": -0.10638428, - "special": false, - "text": "POST" - } - ], - "top_tokens": null - }, - "generated_text": "\n \"\"\"\n if request.method == 'POST" - } -] diff --git a/integration-tests/models/test_flash_llama_gptq_marlin.py b/integration-tests/models/test_flash_llama_gptq_marlin.py deleted file mode 100644 index 9c37a64468c..00000000000 --- a/integration-tests/models/test_flash_llama_gptq_marlin.py +++ /dev/null @@ -1,65 +0,0 @@ -import pytest - - -@pytest.fixture(scope="module") -def flash_llama_gptq_marlin_handle(launcher): - with launcher( - 
"astronomer/Llama-3-8B-Instruct-GPTQ-4-Bit", num_shard=2, quantize="marlin" - ) as handle: - yield handle - - -@pytest.fixture(scope="module") -async def flash_llama_gptq_marlin(flash_llama_gptq_marlin_handle): - await flash_llama_gptq_marlin_handle.health(300) - return flash_llama_gptq_marlin_handle.client - - -@pytest.mark.asyncio -@pytest.mark.private -async def test_flash_llama_gptq_marlin(flash_llama_gptq_marlin, response_snapshot): - response = await flash_llama_gptq_marlin.generate( - "Test request", max_new_tokens=10, decoder_input_details=True - ) - - assert response.details.generated_tokens == 10 - assert response == response_snapshot - - -@pytest.mark.asyncio -@pytest.mark.private -async def test_flash_llama_gptq_marlin_all_params( - flash_llama_gptq_marlin, response_snapshot -): - response = await flash_llama_gptq_marlin.generate( - "Test request", - max_new_tokens=10, - repetition_penalty=1.2, - return_full_text=True, - temperature=0.5, - top_p=0.9, - top_k=10, - truncate=5, - typical_p=0.9, - watermark=True, - decoder_input_details=True, - seed=0, - ) - - assert response.details.generated_tokens == 10 - assert response == response_snapshot - - -@pytest.mark.asyncio -@pytest.mark.private -async def test_flash_llama_gptq_marlin_load( - flash_llama_gptq_marlin, generate_load, response_snapshot -): - responses = await generate_load( - flash_llama_gptq_marlin, "Test request", max_new_tokens=10, n=4 - ) - - assert len(responses) == 4 - assert all([r.generated_text == responses[0].generated_text for r in responses]) - - assert responses == response_snapshot diff --git a/server/text_generation_server/layers/gptq/__init__.py b/server/text_generation_server/layers/gptq/__init__.py index 1172775f096..e4060d5574b 100644 --- a/server/text_generation_server/layers/gptq/__init__.py +++ b/server/text_generation_server/layers/gptq/__init__.py @@ -7,6 +7,15 @@ ) +@dataclass +class GPTQParams: + bits: int + groupsize: int + desc_act: bool + quant_method: str + sym: bool + + @dataclass class GPTQWeight: qweight: torch.Tensor diff --git a/server/text_generation_server/layers/linear.py b/server/text_generation_server/layers/linear.py index d40b192f653..ab1f0c6c0e0 100644 --- a/server/text_generation_server/layers/linear.py +++ b/server/text_generation_server/layers/linear.py @@ -167,35 +167,42 @@ def get_linear(weight, bias, quantize): elif quantize == "gptq": from text_generation_server.layers.gptq import GPTQWeight + from text_generation_server.layers.marlin import GPTQMarlinWeight - if not isinstance(weight, GPTQWeight): - raise NotImplementedError( - f"The passed weight is not `gptq` compatible, loader needs to be updated." + if isinstance(weight, GPTQMarlinWeight): + linear = GPTQMarlinLinear( + weight=weight, + bias=bias, ) - - if weight.use_exllama: - try: - from text_generation_server.layers.gptq import ( - ExllamaQuantLinear, - ) - except ImportError: - raise NotImplementedError( - f"Exllama gptq kernels are not installed. Install them `cd server/exllama_kernels && python setup.py install && cd ../exllamav2_kernels && python setup.py install`" + elif isinstance(weight, GPTQWeight): + if weight.use_exllama: + try: + from text_generation_server.layers.gptq import ( + ExllamaQuantLinear, + ) + except ImportError: + raise NotImplementedError( + f"Exllama gptq kernels are not installed. 
Install them `cd server/exllama_kernels && python setup.py install && cd ../exllamav2_kernels && python setup.py install`" + ) + + linear = ExllamaQuantLinear(weight, bias) + else: + from text_generation_server.layers.gptq.quant_linear import QuantLinear + + linear = QuantLinear( + weight.qweight, + weight.qzeros, + weight.scales, + weight.g_idx, + bias, + weight.bits, + weight.groupsize, ) - - linear = ExllamaQuantLinear(weight, bias) else: - from text_generation_server.layers.gptq.quant_linear import QuantLinear - - linear = QuantLinear( - weight.qweight, - weight.qzeros, - weight.scales, - weight.g_idx, - bias, - weight.bits, - weight.groupsize, + raise NotImplementedError( + f"The passed weight is not `gptq` compatible, loader needs to be updated." ) + elif quantize == "awq": from text_generation_server.layers.gptq import GPTQWeight @@ -225,17 +232,11 @@ def get_linear(weight, bias, quantize): ) elif quantize == "marlin": from text_generation_server.layers.marlin import ( - GPTQMarlinWeight, MarlinLinear, MarlinWeight, ) - if isinstance(weight, GPTQMarlinWeight): - linear = GPTQMarlinLinear( - weight=weight, - bias=bias, - ) - elif isinstance(weight, MarlinWeight): + if isinstance(weight, MarlinWeight): linear = MarlinLinear(weight=weight, bias=bias) else: raise NotImplementedError( diff --git a/server/text_generation_server/layers/marlin.py b/server/text_generation_server/layers/marlin.py index 4d4c635ecf9..5152c64df7f 100644 --- a/server/text_generation_server/layers/marlin.py +++ b/server/text_generation_server/layers/marlin.py @@ -4,6 +4,7 @@ import torch import torch.nn as nn +from text_generation_server.layers.gptq import GPTQParams from text_generation_server.utils.import_utils import SYSTEM try: @@ -23,6 +24,19 @@ MARLIN_TILE_SIZE = 16 +def can_use_gptq_marlin(gptq_params: GPTQParams, quantize: str) -> bool: + return ( + SYSTEM == "cuda" + and marlin_kernels is not None + and has_sm_8_0 + and quantize == "gptq" + and gptq_params.quant_method == "gptq" + and gptq_params.bits in GPTQ_MARLIN_BITS + and gptq_params.groupsize in GPTQ_MARLIN_GROUP_SIZES + and gptq_params.sym + ) + + def _check_marlin_kernels(): if not (SYSTEM == "cuda" and has_sm_8_0): raise NotImplementedError( diff --git a/server/text_generation_server/utils/weights.py b/server/text_generation_server/utils/weights.py index e61425254a0..5b01d2755a3 100644 --- a/server/text_generation_server/utils/weights.py +++ b/server/text_generation_server/utils/weights.py @@ -1,24 +1,15 @@ import os -from dataclasses import dataclass from pathlib import Path -from typing import Dict, List, Optional, Tuple, Union +from typing import Dict, List, Optional, Union from safetensors import safe_open, SafetensorError import torch from loguru import logger from huggingface_hub import hf_hub_download import json +from text_generation_server.layers.gptq import GPTQParams from text_generation_server.utils.log import log_once -@dataclass -class _GPTQParams: - bits: int - groupsize: int - desc_act: bool - quant_method: str - sym: bool - - class Weights: def __init__( self, @@ -211,6 +202,10 @@ def get_weights_col_packed( """ if quantize in ["gptq", "awq"]: from text_generation_server.layers.gptq import GPTQWeight + from text_generation_server.layers.marlin import ( + can_use_gptq_marlin, + repack_gptq_for_marlin, + ) try: qweight = self.get_packed_sharded( @@ -220,17 +215,28 @@ def get_weights_col_packed( raise RuntimeError( f"Cannot load `{quantize}` weight, make sure the model is already quantized." 
) + scales = self.get_packed_sharded( + f"{prefix}.scales", dim=1, block_sizes=block_sizes + ) + scales = scales.to(dtype=self.dtype) gptq_params = self._get_gptq_params() + if can_use_gptq_marlin(gptq_params, quantize): + g_idx = self.get_tensor(f"{prefix}.g_idx") + return repack_gptq_for_marlin( + qweight=qweight, + scales=scales, + g_idx=g_idx, + bits=gptq_params.bits, + desc_act=gptq_params.desc_act, + groupsize=gptq_params.groupsize, + sym=gptq_params.sym, + sharded_infeatures=False, + ) qzeros = self.get_packed_sharded( f"{prefix}.qzeros", dim=1, block_sizes=block_sizes ) - scales = self.get_packed_sharded( - f"{prefix}.scales", dim=1, block_sizes=block_sizes - ) - scales = scales.to(dtype=self.dtype) - if quantize == "gptq" and gptq_params.quant_method == "gptq": g_idx = self.get_tensor(f"{prefix}.g_idx") elif quantize == "gptq" and gptq_params.quant_method == "awq": @@ -262,46 +268,11 @@ def get_weights_col_packed( use_exllama=False, ) elif quantize == "marlin": - from text_generation_server.layers.marlin import ( - MarlinWeight, - repack_gptq_for_marlin, - ) - - quant_method = getattr(self, "quant_method", "marlin") - if quant_method == "gptq": - gptq_params = self._get_gptq_params() - try: - qweight = self.get_packed_sharded( - f"{prefix}.qweight", dim=1, block_sizes=block_sizes - ) - except RuntimeError: - raise RuntimeError( - f"Cannot load `{quantize}` weight for GPTQ -> Marlin repacking, make sure the model is already quantized" - ) - - scales = self.get_packed_sharded( - f"{prefix}.scales", dim=1, block_sizes=block_sizes - ) - g_idx = self.get_tensor(f"{prefix}.g_idx") - weight = repack_gptq_for_marlin( - qweight=qweight, - scales=scales, - g_idx=g_idx, - bits=gptq_params.bits, - desc_act=gptq_params.desc_act, - groupsize=gptq_params.groupsize, - sym=gptq_params.sym, - sharded_infeatures=False, - ) + from text_generation_server.layers.marlin import MarlinWeight - else: - B = self.get_packed_sharded( - f"{prefix}.B", dim=1, block_sizes=block_sizes - ) - s = self.get_packed_sharded( - f"{prefix}.s", dim=1, block_sizes=block_sizes - ) - weight = MarlinWeight(B=B, s=s) + B = self.get_packed_sharded(f"{prefix}.B", dim=1, block_sizes=block_sizes) + s = self.get_packed_sharded(f"{prefix}.s", dim=1, block_sizes=block_sizes) + weight = MarlinWeight(B=B, s=s) else: weight = self.get_packed_sharded( f"{prefix}.weight", dim=0, block_sizes=block_sizes @@ -339,6 +310,10 @@ def get_multi_weights_col(self, prefixes: List[str], quantize: str, dim: int): raise ValueError("get_multi_weights_col is not supported for exl2") elif quantize in ["gptq", "awq"]: from text_generation_server.layers.gptq import GPTQWeight + from text_generation_server.layers.marlin import ( + can_use_gptq_marlin, + repack_gptq_for_marlin, + ) try: qweight = torch.cat( @@ -349,14 +324,31 @@ def get_multi_weights_col(self, prefixes: List[str], quantize: str, dim: int): f"Cannot load `{quantize}` weight, make sure the model is already quantized" ) - qzeros = torch.cat( - [self.get_sharded(f"{p}.qzeros", dim=1) for p in prefixes], dim=1 - ) scales = torch.cat( [self.get_sharded(f"{p}.scales", dim=1) for p in prefixes], dim=1 ) gptq_params = self._get_gptq_params() + if can_use_gptq_marlin(gptq_params, quantize): + w = [self.get_tensor(f"{p}.g_idx") for p in prefixes] + for w2 in w[1:]: + torch.testing.assert_close(w2, w[0]) + g_idx = w[0] + + return repack_gptq_for_marlin( + qweight=qweight, + scales=scales, + g_idx=g_idx, + bits=gptq_params.bits, + desc_act=gptq_params.desc_act, + groupsize=gptq_params.groupsize, + 
sym=gptq_params.sym, + sharded_infeatures=False, + ) + + qzeros = torch.cat( + [self.get_sharded(f"{p}.qzeros", dim=1) for p in prefixes], dim=1 + ) from text_generation_server.layers.gptq import HAS_EXLLAMA @@ -407,54 +399,19 @@ def get_multi_weights_col(self, prefixes: List[str], quantize: str, dim: int): from text_generation_server.layers.gptq import GPTQWeight from text_generation_server.layers.marlin import ( MarlinWeight, - repack_gptq_for_marlin, ) - quant_method = getattr(self, "quant_method", "marlin") - if quant_method == "gptq": - gptq_params = self._get_gptq_params() - try: - qweight = torch.cat( - [self.get_sharded(f"{p}.qweight", dim=1) for p in prefixes], - dim=1, - ) - except RuntimeError: - raise RuntimeError( - f"Cannot load `{quantize}` weight for GPTQ -> Marlin repacking, make sure the model is already quantized" - ) - - scales = torch.cat( - [self.get_sharded(f"{p}.scales", dim=1) for p in prefixes], dim=1 - ) - w = [self.get_tensor(f"{p}.g_idx") for p in prefixes] - for w2 in w[1:]: - torch.testing.assert_close(w2, w[0]) - g_idx = w[0] - - weight = repack_gptq_for_marlin( - qweight=qweight, - scales=scales, - g_idx=g_idx, - bits=gptq_params.bits, - desc_act=gptq_params.desc_act, - groupsize=gptq_params.groupsize, - sym=gptq_params.sym, - sharded_infeatures=False, + try: + B = torch.cat( + [self.get_sharded(f"{p}.B", dim=1) for p in prefixes], dim=1 ) - else: - try: - B = torch.cat( - [self.get_sharded(f"{p}.B", dim=1) for p in prefixes], dim=1 - ) - except RuntimeError: - raise RuntimeError( - f"Cannot load `{quantize}` weight, make sure the model is already quantized" - ) - s = torch.cat( - [self.get_sharded(f"{p}.s", dim=1) for p in prefixes], dim=1 + except RuntimeError: + raise RuntimeError( + f"Cannot load `{quantize}` weight, make sure the model is already quantized" ) + s = torch.cat([self.get_sharded(f"{p}.s", dim=1) for p in prefixes], dim=1) - weight = MarlinWeight(B=B, s=s) + weight = MarlinWeight(B=B, s=s) else: w = [self.get_sharded(f"{p}.weight", dim=0) for p in prefixes] @@ -503,9 +460,41 @@ def get_multi_weights_row(self, prefix: str, quantize: str): ) elif quantize == "gptq": - use_exllama = True + from text_generation_server.layers.marlin import ( + can_use_gptq_marlin, + repack_gptq_for_marlin, + ) + gptq_params = self._get_gptq_params() + if can_use_gptq_marlin(gptq_params, quantize): + log_once(logger.info, "Using GPTQ-Marlin kernels") + try: + qweight = self.get_sharded(f"{prefix}.qweight", dim=0) + except RuntimeError: + raise RuntimeError( + f"Cannot load `{quantize}` weight for GPTQ -> Marlin repacking, make sure the model is already quantized" + ) + g_idx = self.get_sharded(f"{prefix}.g_idx", dim=0) + if gptq_params.desc_act or gptq_params.groupsize == -1: + scales = self.get_tensor(f"{prefix}.scales") + else: + scales = self.get_sharded(f"{prefix}.scales", dim=0) + + sharded_in_features = self.process_group.size() > 1 + + return repack_gptq_for_marlin( + qweight=qweight, + scales=scales, + g_idx=g_idx, + bits=gptq_params.bits, + desc_act=gptq_params.desc_act, + groupsize=gptq_params.groupsize, + sym=gptq_params.sym, + sharded_infeatures=sharded_in_features, + ) + + use_exllama = True if gptq_params.bits != 4: use_exllama = False @@ -630,66 +619,34 @@ def get_multi_weights_row(self, prefix: str, quantize: str): from text_generation_server.layers.gptq import GPTQWeight from text_generation_server.layers.marlin import ( MarlinWeight, - repack_gptq_for_marlin, ) - quant_method = getattr(self, "quant_method", "marlin") - if quant_method == 
"gptq": - log_once(logger.info, "Converting GPTQ model to Marlin packing format.") - gptq_params = self._get_gptq_params() - - try: - qweight = self.get_sharded(f"{prefix}.qweight", dim=0) - except RuntimeError: - raise RuntimeError( - f"Cannot load `{quantize}` weight for GPTQ -> Marlin repacking, make sure the model is already quantized" - ) - - g_idx = self.get_sharded(f"{prefix}.g_idx", dim=0) - if gptq_params.desc_act or gptq_params.groupsize == -1: - scales = self.get_tensor(f"{prefix}.scales") - else: - scales = self.get_sharded(f"{prefix}.scales", dim=0) - - sharded_in_features = self.process_group.size() > 1 - - weight = repack_gptq_for_marlin( - qweight=qweight, - scales=scales, - g_idx=g_idx, - bits=gptq_params.bits, - desc_act=gptq_params.desc_act, - groupsize=gptq_params.groupsize, - sym=gptq_params.sym, - sharded_infeatures=sharded_in_features, + try: + B = self.get_sharded(f"{prefix}.B", dim=0) + except RuntimeError: + raise RuntimeError( + "Cannot load `marlin` weight, make sure the model is already quantized, or quantize it with `text-generation-server quantize ORIGINAL_MODEL_ID NEW_MODEL_ID`" ) - else: - try: - B = self.get_sharded(f"{prefix}.B", dim=0) - except RuntimeError: - raise RuntimeError( - "Cannot load `marlin` weight, make sure the model is already quantized, or quantize it with `text-generation-server quantize ORIGINAL_MODEL_ID NEW_MODEL_ID`" - ) - num_groups = self._get_slice(f"{prefix}.s").get_shape()[0] - if num_groups == 1: - # The number of groups is 1 when groupsize == -1. share - # scales between all shards in this case. - s = self.get_tensor(f"{prefix}.s") - else: - s = self.get_sharded(f"{prefix}.s", dim=0) - weight = MarlinWeight(B=B, s=s) + num_groups = self._get_slice(f"{prefix}.s").get_shape()[0] + if num_groups == 1: + # The number of groups is 1 when groupsize == -1. share + # scales between all shards in this case. + s = self.get_tensor(f"{prefix}.s") + else: + s = self.get_sharded(f"{prefix}.s", dim=0) + weight = MarlinWeight(B=B, s=s) else: weight = self.get_sharded(f"{prefix}.weight", dim=1) return weight - def _get_gptq_params(self) -> _GPTQParams: + def _get_gptq_params(self) -> GPTQParams: try: bits = self.get_tensor("gptq_bits").item() groupsize = self.get_tensor("gptq_groupsize").item() desc_act = False - sym = True + sym = False quant_method = "gptq" except (SafetensorError, RuntimeError) as e: try: @@ -701,7 +658,7 @@ def _get_gptq_params(self) -> _GPTQParams: except Exception: raise e - return _GPTQParams( + return GPTQParams( bits=bits, desc_act=desc_act, groupsize=groupsize,