diff --git a/vllm/attention/backends/pallas.py b/vllm/attention/backends/pallas.py index 83fdef16ef5cb..a8a78d41c666c 100644 --- a/vllm/attention/backends/pallas.py +++ b/vllm/attention/backends/pallas.py @@ -130,7 +130,7 @@ def __init__( assert tpu_type is not None tpu_type = tpu_type.lower() - if "lite" not in tpu_type: + if (("lite" not in tpu_type) and ("v6" not in tpu_type)): if self.num_kv_heads % 2 == 0: self.megacore_mode = "kv_head" else: