Skip to content

Commit

Permalink
[Misc] add w8a8 asym models (vllm-project#11075)
Browse files Browse the repository at this point in the history
  • Loading branch information
dsikka authored Dec 23, 2024
1 parent b866cdb commit 8cef6e0
Showing 1 changed file with 10 additions and 6 deletions.
16 changes: 10 additions & 6 deletions tests/quantization/test_compressed_tensors.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,19 +79,23 @@ def zp_valid(zp: Optional[torch.Tensor]):
assert output


@pytest.mark.parametrize(
"model_path",
[
"neuralmagic/Llama-3.2-1B-quantized.w8a8"
# TODO static & asymmetric
])
@pytest.mark.parametrize("model_path", [
"neuralmagic/Llama-3.2-1B-quantized.w8a8",
"nm-testing/Meta-Llama-3-8B-Instruct-W8A8-Dynamic-Asym",
"nm-testing/Meta-Llama-3-8B-Instruct-W8A8-Static-Per-Tensor-Sym",
"nm-testing/Meta-Llama-3-8B-Instruct-W8A8-Static-Per-Tensor-Asym"
])
@pytest.mark.parametrize("max_tokens", [32])
@pytest.mark.parametrize("num_logprobs", [10])
def test_compressed_tensors_w8a8_logprobs(hf_runner, vllm_runner,
example_prompts, model_path,
max_tokens, num_logprobs):
dtype = "bfloat16"

# skip language translation prompt for the static per tensor asym model
if model_path == "nm-testing/Meta-Llama-3-8B-Instruct-W8A8-Static-Per-Tensor-Asym": # noqa: E501
example_prompts = example_prompts[0:-1]

with hf_runner(model_path, dtype=dtype) as hf_model:
hf_outputs = hf_model.generate_greedy_logprobs_limit(
example_prompts, max_tokens, num_logprobs)
Expand Down

0 comments on commit 8cef6e0

Please sign in to comment.