Ingest fix #340

Merged · 3 commits · Dec 19, 2024
vllm/attention/backends/rocm_flash_attn.py (2 changes: 1 addition & 1 deletion)

@@ -686,7 +686,7 @@ def forward(
full_scales = (
1.0 / q_scale.item(), 1.0 / k_scale.item(),
1.0 / v_scale.item(), 1.0 / prob_scale.item(),
- fp8_out_scale.item()) if fp8_comp_scales else None
+ fp8_out_scale.item()) if fp8_out_scale else None
out, _ = self.attn_func(
query,
key,
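For readers skimming the hunk: the fix ties the construction of `full_scales` to the presence of `fp8_out_scale` itself rather than the broader `fp8_comp_scales` flag. A minimal sketch of that guard pattern follows; the helper name and arguments are illustrative, not the backend's API, and the sketch uses an explicit `is not None` check where the kernel code relies on truthiness.

```python
import torch
from typing import Optional

def build_full_scales(q_scale: torch.Tensor, k_scale: torch.Tensor,
                      v_scale: torch.Tensor, prob_scale: torch.Tensor,
                      fp8_out_scale: Optional[torch.Tensor]):
    # Reciprocal scales for q/k/v/prob plus the output scale itself,
    # or None when no output scale is available (the fixed guard above).
    return (1.0 / q_scale.item(), 1.0 / k_scale.item(),
            1.0 / v_scale.item(), 1.0 / prob_scale.item(),
            fp8_out_scale.item()) if fp8_out_scale is not None else None

# With no output scale the whole tuple collapses to None.
print(build_full_scales(torch.tensor(0.5), torch.tensor(0.5),
                        torch.tensor(0.25), torch.tensor(1.0), None))
```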
vllm/attention/layer.py (7 changes: 4 additions & 3 deletions)

@@ -128,6 +128,7 @@ def __init__(
compilation_config.static_forward_context[prefix] = self
self.layer_name = prefix

+ self.q_range = torch.tensor(envs.Q_SCALE_CONSTANT, dtype=torch.float32)
self.k_range = torch.tensor(envs.K_SCALE_CONSTANT, dtype=torch.float32)
self.v_range = torch.tensor(envs.V_SCALE_CONSTANT, dtype=torch.float32)

@@ -143,8 +144,7 @@ def forward(
) -> torch.Tensor:
if self.calculate_kv_scales and \
attn_metadata.enable_kv_scales_calculation:
- self.calc_kv_scales(key, value)
-
+ self.calc_kv_scales(query, key, value)
if self.use_direct_call:
return self.impl.forward(query,
key,
@@ -176,7 +176,8 @@ def forward(
kv_cache, attn_type,
self.layer_name)

- def calc_kv_scales(self, key, value):
+ def calc_kv_scales(self, query, key, value):
+ self._q_scale.copy_(torch.abs(query).max() / self.q_range)
self._k_scale.copy_(torch.abs(key).max() / self.k_range)
self._v_scale.copy_(torch.abs(value).max() / self.v_range)
# We only calculate the scales once
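With this change the layer derives a per-tensor query scale the same way it already derives key and value scales: the absolute maximum of the activation divided by a configurable range constant. Below is a standalone sketch of that rule, assuming the new default divisors from this PR (20, 20, 10); it is not the Attention class itself.

```python
import torch

# Stand-ins for envs.Q_SCALE_CONSTANT / K_SCALE_CONSTANT / V_SCALE_CONSTANT
# (defaults after this PR: 20, 20, 10).
Q_RANGE, K_RANGE, V_RANGE = 20.0, 20.0, 10.0

def calc_dynamic_scales(query: torch.Tensor, key: torch.Tensor,
                        value: torch.Tensor):
    """Per-tensor dynamic scales: abs-max of the activation over its range."""
    q_scale = torch.abs(query).max() / Q_RANGE
    k_scale = torch.abs(key).max() / K_RANGE
    v_scale = torch.abs(value).max() / V_RANGE
    return q_scale, k_scale, v_scale

# Example usage with dummy activations.
q, k, v = (torch.randn(2, 16, 64) for _ in range(3))
print(calc_dynamic_scales(q, k, v))
```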
vllm/envs.py (19 changes: 13 additions & 6 deletions)

@@ -83,8 +83,9 @@
VLLM_FP8_PADDING: bool = True
VLLM_ENABLE_V1_MULTIPROCESSING: bool = True
VLLM_LOG_BATCHSIZE_INTERVAL: float = -1
- K_SCALE_CONSTANT: int = 200
- V_SCALE_CONSTANT: int = 100
+ Q_SCALE_CONSTANT: int = 20
+ K_SCALE_CONSTANT: int = 20
+ V_SCALE_CONSTANT: int = 10


def get_default_cache_root():
@@ -529,13 +530,19 @@ def get_default_config_root():
"VLLM_FP8_PADDING":
lambda: bool(int(os.getenv("VLLM_FP8_PADDING", "1"))),

- # Divisor for dynamic key scale factor calculation for FP8 KV Cache
+ # Divisor for dynamic query scale factor calculation for FP8 attention
+ "Q_SCALE_CONSTANT":
+ lambda: int(os.getenv("Q_SCALE_CONSTANT", "20")),
+
+ # Divisor for dynamic key scale factor calculation
+ # for FP8 KV Cache and attention
"K_SCALE_CONSTANT":
lambda: int(os.getenv("K_SCALE_CONSTANT", "200")),
lambda: int(os.getenv("K_SCALE_CONSTANT", "20")),

- # Divisor for dynamic value scale factor calculation for FP8 KV Cache
+ # Divisor for dynamic value scale factor calculation
+ # for FP8 KV Cache and attention
"V_SCALE_CONSTANT":
lambda: int(os.getenv("V_SCALE_CONSTANT", "100")),
lambda: int(os.getenv("V_SCALE_CONSTANT", "10")),

# If set, enable multiprocessing in LLM for the V1 code path.
"VLLM_ENABLE_V1_MULTIPROCESSING":
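Because each divisor is resolved through `os.getenv`, it can be overridden per deployment without code changes. A small illustrative example of how the lookups above resolve (the override value 40 is arbitrary):

```python
import os

# Override one divisor before the env registry is read; the others fall back
# to the new defaults (Q=20, K=20, V=10).
os.environ["Q_SCALE_CONSTANT"] = "40"

q_div = int(os.getenv("Q_SCALE_CONSTANT", "20"))
k_div = int(os.getenv("K_SCALE_CONSTANT", "20"))
v_div = int(os.getenv("V_SCALE_CONSTANT", "10"))
print(q_div, k_div, v_div)  # -> 40 20 10
```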
vllm/model_executor/layers/quantization/kv_cache.py (79 changes: 39 additions & 40 deletions)

@@ -42,53 +42,52 @@ def apply(self, layer: torch.nn.Module) -> torch.Tensor:
f"{self.__class__.__name__}.apply should not be called.")

def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
- # No need to process kv scales after loading if we are going to
- # calculate them on the fly.
- if not layer.calculate_kv_scales:
- if layer.k_scale > 0.0 and layer.v_scale > 0.0:
- # We prefer to use separate k_scale and v_scale if present
- k_scale = layer.k_scale.to("cpu").tolist()
- v_scale = layer.v_scale.to("cpu").tolist()
- if current_platform.is_rocm() and not is_navi():
- k_scale *= 2
- v_scale *= 2
- elif layer.k_scale < 0.0 and layer.v_scale < 0.0:
- # If no scales were loaded (both scales are invalid negative
- # values), use the default value of 1.0
- k_scale = 1.0
- v_scale = 1.0
- else:
- # If we find a single kv_scale in the checkpoint, we remap
- # kv_scale to k_scale during weight loading, and duplicate
- # k_scale to v_scale here
- assert layer.k_scale > 0.0
- scale_to_duplicate = max(layer.k_scale, layer.v_scale)
- k_scale = scale_to_duplicate.to("cpu").tolist()
- v_scale = scale_to_duplicate.to("cpu").tolist()
- if current_platform.is_rocm() and not is_navi():
- k_scale *= 2
- v_scale *= 2
+ if layer.k_scale > 0.0 and layer.v_scale > 0.0:
+ # We prefer to use separate k_scale and v_scale if present
+ k_scale = layer.k_scale.to("cpu").tolist()
+ v_scale = layer.v_scale.to("cpu").tolist()
+ if current_platform.is_rocm() and not is_navi():
+ k_scale *= 2
+ v_scale *= 2
+ layer.calculate_kv_scales = False
+ elif layer.k_scale < 0.0 and layer.v_scale < 0.0:
+ # If no scales were loaded (both scales are invalid negative
+ # values), use the default value of 1.0
+ k_scale = 1.0
+ v_scale = 1.0
+ else:
+ # If we find a single kv_scale in the checkpoint, we remap
+ # kv_scale to k_scale during weight loading, and duplicate
+ # k_scale to v_scale here
+ assert layer.k_scale > 0.0
+ scale_to_duplicate = max(layer.k_scale, layer.v_scale)
+ k_scale = scale_to_duplicate.to("cpu").tolist()
+ v_scale = scale_to_duplicate.to("cpu").tolist()
+ if current_platform.is_rocm() and not is_navi():
+ k_scale *= 2
+ v_scale *= 2
+ layer.calculate_kv_scales = False

- if not isinstance(k_scale, float) or not isinstance(
- v_scale, float):
- raise ValueError("Only support per-tensor scaling factor "
- "for fp8 KV cache")
+ if not isinstance(k_scale, float) or not isinstance(v_scale, float):
+ raise ValueError("Only support per-tensor scaling factor "
+ "for fp8 KV cache")

- # These are used in the final Attention.forward()
- layer._k_scale.copy_(k_scale)
- layer._v_scale.copy_(v_scale)
- if (k_scale == 1.0 and v_scale == 1.0
- and (layer.kv_cache_dtype != "auto" or layer.use_fp8)
- and "e5m2" not in layer.kv_cache_dtype):
- print_warning_once(
- "Using KV cache scaling factor 1.0 for fp8_e4m3. This "
- "may cause accuracy issues. Please make sure k/v_scale "
- "scaling factors are available in the fp8 checkpoint.")
+ # These are used in the final Attention.forward()
+ layer._k_scale.copy_(k_scale)
+ layer._v_scale.copy_(v_scale)
+ if (k_scale == 1.0 and v_scale == 1.0
+ and (layer.kv_cache_dtype != "auto" or layer.use_fp8)
+ and "e5m2" not in layer.kv_cache_dtype):
+ print_warning_once(
+ "Using KV cache scaling factor 1.0 for fp8_e4m3. This "
+ "may cause accuracy issues. Please make sure k/v_scale "
+ "scaling factors are available in the fp8 checkpoint.")

if layer.q_scale > 0.0 and layer.prob_scale > 0.0:
q_scale = layer.q_scale.to("cpu").tolist()
if current_platform.is_rocm() and not is_navi():
q_scale *= 2
+ layer.calculate_kv_scales = False
else:
q_scale = 1.0
if layer.prob_scale > 0.0:
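The net effect of the kv_cache.py rewrite is that checkpoint scales are resolved unconditionally, and on-the-fly scale calculation is disabled only when usable scales are actually found. Below is a self-contained sketch of that branch logic using plain floats; the function name and boolean return value are invented for illustration, and negative inputs stand in for the "no scale loaded" sentinel used in the diff.

```python
def resolve_kv_scales(k_scale: float, v_scale: float,
                      rocm_non_navi: bool = False):
    """Return (k_scale, v_scale, checkpoint_scales_found).

    checkpoint_scales_found=True corresponds to the branches above that set
    layer.calculate_kv_scales = False.
    """
    if k_scale > 0.0 and v_scale > 0.0:
        # Separate k/v scales present in the checkpoint.
        if rocm_non_navi:
            k_scale, v_scale = k_scale * 2, v_scale * 2
        return k_scale, v_scale, True
    if k_scale < 0.0 and v_scale < 0.0:
        # Nothing loaded: fall back to 1.0 and keep dynamic calculation enabled.
        return 1.0, 1.0, False
    # A single kv_scale was remapped to k_scale; duplicate it to v_scale.
    scale = max(k_scale, v_scale)
    if rocm_non_navi:
        scale *= 2
    return scale, scale, True

# Example: a checkpoint that only carried kv_scale=0.02, on ROCm (non-Navi).
print(resolve_kv_scales(0.02, -1.0, rocm_non_navi=True))  # (0.04, 0.04, True)
```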
vllm/model_executor/models/llama.py (4 changes: 1 addition & 3 deletions)

@@ -201,9 +201,7 @@ def __init__(
self.attn_fp8 = envs.VLLM_USE_ROCM_FP8_ATTN \
and current_platform.is_rocm() \
and not is_navi() \
- and isinstance(quant_config, Fp8Config) \
- and hasattr(self.o_proj, "input_scale") \
- and self.o_proj.input_scale is not None
+ and isinstance(quant_config, Fp8Config)

self.attn = Attention(
self.num_heads,