Binary file added tests/unit_tests/ops/data/awq/input.pt
Binary file not shown.
Binary file added tests/unit_tests/ops/data/awq/output.pt
Binary file not shown.
Binary file added tests/unit_tests/ops/data/awq/qweight.pt
Binary file not shown.
Binary file added tests/unit_tests/ops/data/awq/qzeros.pt
Binary file not shown.
Binary file added tests/unit_tests/ops/data/awq/scales.pt
Binary file not shown.
Binary file added tests/unit_tests/ops/data/fp8/linear_input.pt
Binary file not shown.
Binary file added tests/unit_tests/ops/data/fp8/linear_output.pt
Binary file not shown.
Binary file added tests/unit_tests/ops/data/fp8/linear_weight.pt
Binary file not shown.
Binary file added tests/unit_tests/ops/data/fp8/linear_weight_scale_inv.pt
Binary file not shown.
Binary file added tests/unit_tests/ops/data/fp8/moe_input_hidden_states.pt
Binary file not shown.
Binary file added tests/unit_tests/ops/data/fp8/moe_input_router_logits.pt
Binary file not shown.
Binary file added tests/unit_tests/ops/data/fp8/moe_output.pt
Binary file not shown.
Binary file added tests/unit_tests/ops/data/fp8/moe_w13_weight.pt
Binary file not shown.
Binary file added tests/unit_tests/ops/data/fp8/moe_w13_weight_scale_inv.pt
Binary file not shown.
Binary file added tests/unit_tests/ops/data/fp8/moe_w2_weight.pt
Binary file not shown.
Binary file added tests/unit_tests/ops/data/fp8/moe_w2_weight_scale_inv.pt
Binary file not shown.
Binary file added tests/unit_tests/ops/data/gptq/input.pt
Binary file not shown.
Binary file added tests/unit_tests/ops/data/gptq/output.pt
Binary file not shown.
Binary file added tests/unit_tests/ops/data/gptq/qweight.pt
Binary file not shown.
Binary file added tests/unit_tests/ops/data/gptq/qzeros.pt
Binary file not shown.
Binary file added tests/unit_tests/ops/data/gptq/scales.pt
Binary file not shown.
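
The tests below read these fixtures through a get_data_path helper imported from the test suite's local utils module. A minimal sketch of what such a helper might look like, assuming it simply resolves a relative path such as "data/awq/qweight.pt" against the directory containing the tests (the real helper in utils.py may differ):

import os


def get_data_path(relative_path: str) -> str:
    # Hypothetical sketch: resolve the fixture path against the test directory,
    # so the tests can be launched from any working directory.
    return os.path.join(os.path.dirname(os.path.abspath(__file__)), relative_path)
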
55 changes: 55 additions & 0 deletions tests/unit_tests/ops/test_hpu_awq.py
@@ -0,0 +1,55 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import torch
import habana_frameworks.torch as htorch
from utils import get_data_path
from vllm_gaudi.ops.hpu_awq import AWQHPULinearMethod, AWQHPUConfig
from vllm_gaudi.utils import HPUCompileConfig
from vllm.model_executor.layers.linear import RowParallelLinear


def test_awq_linear_method(dist_init):
config = {"bits": 4, "group_size": 128, "zero_point": True}
oot_quant_config = AWQHPUConfig.from_config(config)

# Prepare linear layer with oot AWQHPULinearMethod
oot_op = RowParallelLinear(input_size=256,
output_size=128,
bias=False,
input_is_parallel=True,
skip_bias_add=False,
params_dtype=torch.bfloat16,
reduce_results=True,
quant_config=oot_quant_config,
return_bias=False,
disable_tp=False).to("hpu")
assert isinstance(oot_op.quant_method, AWQHPULinearMethod)

# qweight, qzeros and scales were extracted from the first RowParallelLinear layer of TheBloke/Llama-2-7B-Chat-AWQ
# (with adjusted shapes, to make the tensors smaller)
qweight = torch.load(get_data_path("data/awq/qweight.pt"), weights_only=False, map_location="hpu")
oot_op.qweight.copy_(qweight)
qzeros = torch.load(get_data_path("data/awq/qzeros.pt"), weights_only=False, map_location="hpu")
oot_op.qzeros.copy_(qzeros)
scales = torch.load(get_data_path("data/awq/scales.pt"), weights_only=False, map_location="hpu").to(torch.bfloat16)
oot_op.scales.copy_(scales)

oot_op.quant_method.process_weights_after_loading(oot_op)

if not htorch.utils.internal.is_lazy():
compile_config = HPUCompileConfig()
oot_op = torch.compile(oot_op, **compile_config.get_compile_args())

# Input and expected output
# The output tensor holds the data returned by the CUDA implementation of AWQLinearMethod for the given input
# (AWQLinearMethod was run offline with the same input as below to produce ref_output)
input = torch.load(get_data_path("data/awq/input.pt"), weights_only=False, map_location="hpu").to(torch.bfloat16)
ref_output = torch.load(get_data_path("data/awq/output.pt"), weights_only=False,
map_location="hpu").to(torch.bfloat16)

# Execute layer
out = oot_op(input)

# Check correctness
torch.testing.assert_close(ref_output, out, atol=1e-3, rtol=1e-3)
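
For context, an AWQ 4-bit checkpoint packs eight 4-bit weights into each int32 with an interleaved nibble order, and the quantized layer is mathematically a dense matmul against the dequantized weight. The plain-PyTorch sketch below documents that dequantization under the standard AWQ layout assumed here (qweight [in, out // 8], qzeros [in // group_size, out // 8], scales [in // group_size, out]); it is only an illustration of what the fixtures encode, not the AWQHPULinearMethod implementation.

import torch

# Undo AWQ's nibble interleaving: logical column j lives at bit offset 4 * AWQ_REVERSE_ORDER[j].
AWQ_REVERSE_ORDER = [0, 4, 1, 5, 2, 6, 3, 7]


def awq_unpack(packed: torch.Tensor) -> torch.Tensor:
    # packed: int32 tensor whose last dim holds eight 4-bit values per element
    shifts = torch.tensor([4 * i for i in AWQ_REVERSE_ORDER], dtype=torch.int32, device=packed.device)
    nibbles = (packed.unsqueeze(-1) >> shifts) & 0xF
    return nibbles.reshape(*packed.shape[:-1], packed.shape[-1] * 8)


def awq_dequantize(qweight, qzeros, scales, group_size=128):
    w = awq_unpack(qweight).to(scales.dtype)         # [in, out]
    z = awq_unpack(qzeros).to(scales.dtype)          # [in // group_size, out]
    z = z.repeat_interleave(group_size, dim=0)       # one zero-point per group of input rows
    s = scales.repeat_interleave(group_size, dim=0)  # one scale per group of input rows
    return (w - z) * s                               # dense [in, out] weight


# Usage with the fixtures above (conceptually): out = input @ awq_dequantize(qweight, qzeros, scales)
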
140 changes: 140 additions & 0 deletions tests/unit_tests/ops/test_hpu_fp8.py
@@ -0,0 +1,140 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import torch
import habana_frameworks.torch as htorch
from utils import get_data_path
from unittest.mock import MagicMock
from vllm_gaudi.ops.hpu_fp8 import Fp8LinearMethod, HPUFp8MoEMethod
from vllm_gaudi.utils import HPUCompileConfig
from vllm.forward_context import override_forward_context
from vllm.model_executor.layers.quantization.fp8 import Fp8Config
from vllm.model_executor.layers.linear import RowParallelLinear
from vllm.model_executor.layers.fused_moe.layer import FusedMoE


def test_fp8_linear_method(dist_init, monkeypatch):
monkeypatch.setenv("VLLM_HPU_FORCE_CHANNEL_FP8", "0")
config = {'activation_scheme': 'dynamic', 'fmt': 'e4m3', 'quant_method': 'fp8', 'weight_block_size': [128, 128]}
oot_quant_config = Fp8Config.from_config(config)

# Prepare linear layer with oot Fp8LinearMethod
oot_op = RowParallelLinear(input_size=256,
output_size=256,
bias=False,
input_is_parallel=True,
skip_bias_add=False,
params_dtype=torch.bfloat16,
reduce_results=True,
quant_config=oot_quant_config,
return_bias=False,
disable_tp=False).to("hpu")
assert isinstance(oot_op.quant_method, Fp8LinearMethod)

# weight and weight_scale_inv were extracted from the first RowParallelLinear layer of Qwen/Qwen3-8B-FP8
# (with adjusted shapes, to make the tensors smaller)
weight = torch.load(get_data_path("data/fp8/linear_weight.pt"), weights_only=False, map_location="hpu")
oot_op.weight.copy_(weight)
weight_scale_inv = torch.load(get_data_path("data/fp8/linear_weight_scale_inv.pt"),
weights_only=False,
map_location="hpu")
oot_op.weight_scale_inv.copy_(weight_scale_inv)

oot_op.quant_method.process_weights_after_loading(oot_op)

if not htorch.utils.internal.is_lazy():
# Set fullgraph to False because there is currently a graph break
compile_config = HPUCompileConfig(fullgraph=False)
oot_op = torch.compile(oot_op, **compile_config.get_compile_args())

# Input and expected output
# The output tensor holds the data returned by the CUDA implementation of Fp8LinearMethod for the given input
# (Fp8LinearMethod was run offline with the same input as below to produce ref_output)
input = torch.load(get_data_path("data/fp8/linear_input.pt"), weights_only=False, map_location="hpu")
ref_output = torch.load(get_data_path("data/fp8/linear_output.pt"), weights_only=False, map_location="hpu")

# Execute layer
out = oot_op(input)

# Check correctness
torch.testing.assert_close(ref_output, out, atol=1e-2, rtol=1e-2)
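
For context, with weight_block_size=[128, 128] the checkpoint stores the weight in torch.float8_e4m3fn together with one scale per 128x128 block (weight_scale_inv), and dequantization is an element-wise multiply with each block scale broadcast over its block. A minimal reference sketch of that scheme, not the Fp8LinearMethod code path itself:

import torch


def dequant_block_fp8(weight: torch.Tensor, weight_scale_inv: torch.Tensor,
                      block_size=(128, 128), dtype=torch.bfloat16) -> torch.Tensor:
    # weight: [out, in] in torch.float8_e4m3fn
    # weight_scale_inv: [ceil(out / 128), ceil(in / 128)], one scale per 128x128 block
    bo, bi = block_size
    scales = weight_scale_inv.repeat_interleave(bo, dim=0)[:weight.shape[0]]
    scales = scales.repeat_interleave(bi, dim=1)[:, :weight.shape[1]]
    return weight.to(dtype) * scales.to(dtype)
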


def test_fp8_moe_method(dist_init, monkeypatch):
monkeypatch.setenv("VLLM_HPU_FORCE_CHANNEL_FP8", "0")
config = {
'activation_scheme': 'dynamic',
'modules_to_not_convert': [],
'fmt': 'e4m3',
'quant_method': 'fp8',
'weight_block_size': [128, 128]
}
oot_quant_config = Fp8Config.from_config(config)

# Prepare FusedMoE layer with oot HPUFp8MoEMethod
oot_op = FusedMoE(num_experts=128,
top_k=8,
hidden_size=512,
intermediate_size=256,
params_dtype=torch.bfloat16,
reduce_results=True,
renormalize=True,
use_grouped_topk=False,
num_expert_group=None,
topk_group=None,
quant_config=oot_quant_config,
tp_size=None,
ep_size=None,
dp_size=None,
custom_routing_function=None,
scoring_func="softmax",
routed_scaling_factor=1.0,
e_score_correction_bias=None,
apply_router_weight_on_input=False,
activation="silu",
enable_eplb=False,
num_redundant_experts=0,
has_bias=False,
is_sequence_parallel=False).to("hpu")
assert isinstance(oot_op.quant_method, HPUFp8MoEMethod)

# Weights were extracted from the first FusedMoE layer of Qwen/Qwen3-30B-A3B-FP8
# (with adjusted shapes, to make the tensors smaller)
w13_weight = torch.load(get_data_path("data/fp8/moe_w13_weight.pt"), weights_only=False, map_location="hpu")
oot_op.w13_weight.copy_(w13_weight.repeat(128, 1, 1))
w13_weight_scale_inv = torch.load(get_data_path("data/fp8/moe_w13_weight_scale_inv.pt"),
weights_only=False,
map_location="hpu")
oot_op.w13_weight_scale_inv.copy_(w13_weight_scale_inv.repeat(128, 1, 1))
w2_weight = torch.load(get_data_path("data/fp8/moe_w2_weight.pt"), weights_only=False, map_location="hpu")
oot_op.w2_weight.copy_(w2_weight.repeat(128, 1, 1))
w2_weight_scale_inv = torch.load(get_data_path("data/fp8/moe_w2_weight_scale_inv.pt"),
weights_only=False,
map_location="hpu")
oot_op.w2_weight_scale_inv.copy_(w2_weight_scale_inv.repeat(128, 1, 1))

oot_op.quant_method.process_weights_after_loading(oot_op)

if not htorch.utils.internal.is_lazy():
compile_config = HPUCompileConfig()
oot_op = torch.compile(oot_op, **compile_config.get_compile_args())

# Input and expected output
# The output tensor holds the data returned by the CUDA implementation of Fp8MoEMethod for the given input
# (Fp8MoEMethod was run offline with the same input as below to produce ref_output)
hidden_states = torch.load(get_data_path("data/fp8/moe_input_hidden_states.pt"),
weights_only=False,
map_location="hpu")
router_logits = torch.load(get_data_path("data/fp8/moe_input_router_logits.pt"),
weights_only=False,
map_location="hpu")
ref_output = torch.load(get_data_path("data/fp8/moe_output.pt"), weights_only=False, map_location="hpu")

# Execute layer
mock_ctx = MagicMock(spec=["dp_metadata"])
mock_ctx.dp_metadata = None
with override_forward_context(mock_ctx):
out = oot_op.forward_impl(hidden_states, router_logits)

# Check correctness
torch.testing.assert_close(ref_output, out, atol=1e-3, rtol=1e-3)
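
For context, the FusedMoE configuration above (top_k=8, softmax scoring, renormalize=True, activation="silu") corresponds to the dense routing and expert computation sketched below, written for already dequantized bf16 expert weights (w13 stacking the gate and up projections, w2 the down projection). This is only a readability sketch of the semantics being verified, not how HPUFp8MoEMethod executes on HPU.

import torch
import torch.nn.functional as F


def ref_fused_moe(hidden_states, router_logits, w13, w2, top_k=8):
    # hidden_states: [tokens, hidden]; w13: [experts, 2 * inter, hidden]; w2: [experts, hidden, inter]
    probs = torch.softmax(router_logits.float(), dim=-1)
    topk_weights, topk_ids = torch.topk(probs, top_k, dim=-1)
    topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)  # renormalize=True
    out = torch.zeros_like(hidden_states)
    for expert in range(w13.shape[0]):
        token_ids, slot = (topk_ids == expert).nonzero(as_tuple=True)
        if token_ids.numel() == 0:
            continue
        x = hidden_states[token_ids]
        gate, up = (x @ w13[expert].t()).chunk(2, dim=-1)
        expert_out = (F.silu(gate) * up) @ w2[expert].t()  # "silu" (SwiGLU-style) expert MLP
        weight = topk_weights[token_ids, slot].unsqueeze(-1).to(out.dtype)
        out.index_add_(0, token_ids, expert_out * weight)
    return out
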
55 changes: 55 additions & 0 deletions tests/unit_tests/ops/test_hpu_gptq.py
@@ -0,0 +1,55 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import torch
import habana_frameworks.torch as htorch
from utils import get_data_path
from vllm_gaudi.ops.hpu_gptq import GPTQHPULinearMethod, GPTQHPUConfig
from vllm_gaudi.utils import HPUCompileConfig
from vllm.model_executor.layers.linear import RowParallelLinear


def test_gptq_linear_method(dist_init):
config = {"bits": 4, "group_size": 128, "desc_act": False, "lm_head": False}
oot_quant_config = GPTQHPUConfig.from_config(config)

# Prepare linear layer with oot GPTQHPULinearMethod
oot_op = RowParallelLinear(input_size=256,
output_size=8,
bias=False,
input_is_parallel=True,
skip_bias_add=False,
params_dtype=torch.bfloat16,
reduce_results=True,
quant_config=oot_quant_config,
return_bias=False,
disable_tp=False).to("hpu")
assert isinstance(oot_op.quant_method, GPTQHPULinearMethod)

# qweight, qzeros and scales were extracted from the first RowParallelLinear layer of TheBloke/Llama-2-7B-Chat-GPTQ
# (with adjusted shapes, to make the tensors smaller)
qweight = torch.load(get_data_path("data/gptq/qweight.pt"), weights_only=False, map_location="hpu")
oot_op.qweight.copy_(qweight)
qzeros = torch.load(get_data_path("data/gptq/qzeros.pt"), weights_only=False, map_location="hpu")
oot_op.qzeros.copy_(qzeros)
scales = torch.load(get_data_path("data/gptq/scales.pt"), weights_only=False, map_location="hpu").to(torch.bfloat16)
oot_op.scales.copy_(scales)

oot_op.quant_method.process_weights_after_loading(oot_op)

if not htorch.utils.internal.is_lazy():
compile_config = HPUCompileConfig()
oot_op = torch.compile(oot_op, **compile_config.get_compile_args())

# Input and expected output
# The output tensor holds the data returned by the CUDA implementation of GPTQLinearMethod for the given input
# (GPTQLinearMethod was run offline with the same input as below to produce ref_output)
input = torch.load(get_data_path("data/gptq/input.pt"), weights_only=False, map_location="hpu").to(torch.bfloat16)
ref_output = torch.load(get_data_path("data/gptq/output.pt"), weights_only=False,
map_location="hpu").to(torch.bfloat16)

# Execute layer
out = oot_op(input)

# Check correctness
torch.testing.assert_close(ref_output, out, atol=1e-3, rtol=1e-3)
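
For context, a 4-bit GPTQ checkpoint with desc_act=False packs eight sequential 4-bit values per int32 along the input dimension of qweight and along the output dimension of qzeros, and the classic format stores zero-points minus one. The sketch below documents that assumed layout (qweight [in // 8, out], qzeros [in // group_size, out // 8], scales [in // group_size, out]) for readers of the fixtures; it is not the GPTQHPULinearMethod implementation.

import torch


def gptq_unpack(packed: torch.Tensor, dim: int) -> torch.Tensor:
    # Eight sequential 4-bit values per int32 along `dim` (no AWQ-style interleaving).
    shifts = torch.tensor([4 * i for i in range(8)], dtype=torch.int32, device=packed.device)
    nibbles = (packed.unsqueeze(-1) >> shifts) & 0xF
    return nibbles.movedim(-1, dim + 1).reshape(
        *packed.shape[:dim], packed.shape[dim] * 8, *packed.shape[dim + 1:])


def gptq_dequantize(qweight, qzeros, scales, group_size=128):
    w = gptq_unpack(qweight, dim=0).to(scales.dtype)  # [in, out]
    z = gptq_unpack(qzeros, dim=1).to(scales.dtype) + 1  # classic GPTQ format stores zeros - 1
    z = z.repeat_interleave(group_size, dim=0)
    s = scales.repeat_interleave(group_size, dim=0)
    return (w - z) * s                                # dense [in, out] weight
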
71 changes: 71 additions & 0 deletions tests/unit_tests/ops/test_hpu_layernorm.py
@@ -0,0 +1,71 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import pytest
import torch
import habana_frameworks.torch as htorch
from utils import temporary_op_registry_oot, register_op
from vllm_gaudi.ops.hpu_layernorm import HPURMSNorm
from vllm_gaudi.utils import HPUCompileConfig
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.platforms import current_platform

DTYPES = [torch.bfloat16, torch.float]
NUM_TOKENS = [7, 83, 4096]
HIDDEN_SIZES = [8, 768, 769, 770, 771, 5120, 5124, 5125, 5126, 8192, 8199]
ADD_RESIDUAL = [False, True]
DEVICE = [current_platform.device_type]
IS_STRIDED = [False, True]


@pytest.mark.parametrize("num_tokens", NUM_TOKENS)
@pytest.mark.parametrize("hidden_size", HIDDEN_SIZES)
@pytest.mark.parametrize("add_residual", ADD_RESIDUAL)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("device", DEVICE)
@pytest.mark.parametrize("strided_input", IS_STRIDED)
def test_rms_norm(
num_tokens: int,
hidden_size: int,
add_residual: bool,
dtype: torch.dtype,
device: str,
strided_input: bool,
) -> None:
with temporary_op_registry_oot():
# Prepare native RMSNorm module
native_rms_norm = RMSNorm(hidden_size=hidden_size, eps=1e-05)
native_rms_norm = native_rms_norm.to(dtype=dtype).to(device)
native_rms_norm.weight.data.normal_(mean=1.0, std=0.1)
assert isinstance(native_rms_norm, RMSNorm) and not isinstance(native_rms_norm, HPURMSNorm)

# Prepare oot HPURMSNorm module
register_op(RMSNorm, HPURMSNorm)
oot_rms_norm = RMSNorm(hidden_size=hidden_size, eps=1e-05)
oot_rms_norm = oot_rms_norm.to(dtype=dtype).to(device)
oot_rms_norm.weight.data = native_rms_norm.weight.data.clone()
assert isinstance(oot_rms_norm, RMSNorm) and isinstance(oot_rms_norm, HPURMSNorm)

if not htorch.utils.internal.is_lazy():
compile_config = HPUCompileConfig()
oot_rms_norm = torch.compile(oot_rms_norm, **compile_config.get_compile_args())

# Prepare input data
scale = 1 / (2 * hidden_size)
last_dim = 2 * hidden_size if strided_input else hidden_size
x = torch.randn(num_tokens, last_dim, dtype=dtype, device=device)
x = x[..., :hidden_size]
assert x.is_contiguous() != strided_input
x *= scale
residual = torch.randn_like(x) * scale if add_residual else None

# Execute layers
ref_out = native_rms_norm(x, residual)
out = oot_rms_norm(x, residual)

# Check correctness
if add_residual:
torch.testing.assert_close(out[0], ref_out[0], atol=1e-2, rtol=1e-2)
torch.testing.assert_close(out[1], ref_out[1], atol=1e-2, rtol=1e-2)
else:
torch.testing.assert_close(out, ref_out, atol=1e-2, rtol=1e-2)
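
For context, the behavior this test checks (optional fused residual add that also returns the updated residual, normalization in float32, per-channel weight scaling) can be written as the plain-PyTorch reference below. It assumes the usual float32 upcast convention and is only a sketch of the expected semantics, not the HPURMSNorm kernel.

import torch


def ref_rms_norm(x, weight, eps=1e-5, residual=None):
    orig_dtype = x.dtype
    x = x.float()
    if residual is not None:
        x = x + residual.float()  # fused add: the updated residual is returned alongside the output
        residual_out = x.to(orig_dtype)
    variance = x.pow(2).mean(dim=-1, keepdim=True)
    out = (x * torch.rsqrt(variance + eps)).to(orig_dtype) * weight
    return out if residual is None else (out, residual_out)
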