MoE Marlin: support desc_act for groupsize != -1 (#2590)
This change uses the updated Marlin MoE kernel from vLLM to support
MoE with activation sorting and groups.
danieldk authored Sep 30, 2024
1 parent d1f257a commit 1c84a30
Showing 5 changed files with 6 additions and 19 deletions.
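
For context: GPTQ's activation reordering (desc_act) permutes weight columns, so with a finite groupsize the per-column group indices are no longer contiguous and a kernel must gather scales through an explicit g_idx map instead of striding. A rough standalone illustration of that effect (none of these names come from this repository):

import torch

groupsize = 128
in_features = 512

# Without desc_act, column i reads its scales from group i // groupsize,
# so group membership is a simple stride.
g_idx = torch.arange(in_features) // groupsize

# With desc_act the columns are permuted (a random permutation here stands
# in for the activation-magnitude ordering), so group membership is
# scattered and the kernel must follow g_idx explicitly.
perm = torch.randperm(in_features)
g_idx_desc_act = g_idx[perm]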
7 changes: 4 additions & 3 deletions flake.lock

Some generated files are not rendered by default.

2 changes: 1 addition & 1 deletion flake.nix
@@ -5,7 +5,7 @@
       inputs.nixpkgs.follows = "tgi-nix/nixpkgs";
     };
     nix-filter.url = "github:numtide/nix-filter";
-    tgi-nix.url = "github:danieldk/tgi-nix";
+    tgi-nix.url = "github:danieldk/tgi-nix/moe-kernels-0.5.0";
     nixpkgs.follows = "tgi-nix/nixpkgs";
     flake-utils.url = "github:numtide/flake-utils";
     rust-overlay = {
3 changes: 1 addition & 2 deletions server/text_generation_server/layers/marlin/gptq.py
@@ -109,7 +109,6 @@ def get_weights_col_packed(
         prefix: str,
         block_sizes: Union[int, List[int]],
     ):
-
         try:
             qweight = weights.get_packed_sharded(
                 f"{prefix}.qweight", dim=1, block_sizes=block_sizes
@@ -352,7 +351,7 @@ def repack_gptq_for_marlin(

     scales = permute_scales(scales)

-    is_full_k = not (desc_act and sharded_infeatures)
+    is_full_k = not (desc_act and groupsize != -1 and sharded_infeatures)

     return GPTQMarlinWeight(
         qweight=repacked,
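
The condition change above narrows when full-K status is lost: a sharded, act-ordered weight now only gives up full-K when it actually has groups. A minimal before/after sketch (hypothetical standalone functions, not the repository's code):

def is_full_k_before(desc_act: bool, groupsize: int, sharded_infeatures: bool) -> bool:
    # Old rule: any act-ordered weight with sharded in-features was
    # treated as partial-K, even channelwise (groupsize == -1) weights.
    return not (desc_act and sharded_infeatures)

def is_full_k_after(desc_act: bool, groupsize: int, sharded_infeatures: bool) -> bool:
    # New rule: channelwise weights keep full-K, since without groups
    # there is no scale reordering for desc_act to break across shards.
    return not (desc_act and groupsize != -1 and sharded_infeatures)

# The only case that changes: desc_act with channelwise quantization.
assert is_full_k_before(True, -1, True) is False
assert is_full_k_after(True, -1, True) is True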
3 changes: 0 additions & 3 deletions server/text_generation_server/layers/moe/__init__.py
@@ -249,12 +249,9 @@ def is_supported(weights: Weights) -> bool:
         or (
             isinstance(weights.loader, GPTQMarlinWeightsLoader)
             and can_use_marlin_moe_gemm(
-                desc_act=weights.loader.desc_act,
-                groupsize=weights.loader.groupsize,
                 quant_method=weights.loader.quant_method,
                 quantize=weights.loader.quantize,
                 sym=weights.loader.sym,
-                use_tp=weights.process_group.size() > 1,
             )
         )
     )
10 changes: 0 additions & 10 deletions server/text_generation_server/layers/moe/gptq_marlin.py
@@ -26,12 +26,9 @@

 def can_use_marlin_moe_gemm(
     *,
-    desc_act: bool,
-    groupsize: int,
     quant_method: str,
     quantize: str,
     sym: bool,
-    use_tp: bool,
 ):
     return (
         SYSTEM == "cuda"
@@ -40,16 +37,9 @@ def can_use_marlin_moe_gemm(
         and quantize == "gptq"
         and quant_method == "gptq"
         and sym
-        and is_full_k(desc_act, groupsize, use_tp)
     )


-def is_full_k(desc_act: bool, groupsize: int, use_tp: bool):
-    if groupsize == -1:
-        return True
-    return not (desc_act and use_tp)
-
-
 @dataclass
 class GPTQMarlinMoEWeight:
     qweight: torch.Tensor
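
Net effect of the two hunks above: the gate becomes a pure quantization-scheme check, since the updated vLLM Marlin MoE kernel handles act-ordered grouped weights itself. A condensed sketch of the resulting predicate (the SYSTEM constant, the elided capability checks, and the imports are assumed from the surrounding file):

def can_use_marlin_moe_gemm(*, quant_method: str, quantize: str, sym: bool):
    return (
        SYSTEM == "cuda"
        # ...hardware capability checks elided (hidden between the hunks)...
        and quantize == "gptq"
        and quant_method == "gptq"
        and sym
        # The is_full_k(desc_act, groupsize, use_tp) gate is gone: the new
        # kernel supports desc_act together with groupsize != -1 directly.
    )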
