
Commit 8ff3966

Merge pull request #1505 from kvcache-ai/support-qwen3next
fix qwen3next bug
2 parents 880daa7 + d9c75cb

File tree: 4 files changed (+31, -12 lines)


doc/en/Qwen3-Next.md

Lines changed: 1 addition & 0 deletions

@@ -38,6 +38,7 @@ To install KTransformers, follow the official [Installation Guide](https://kvcac
 python ktransformers/server/main.py \
   --port 10021 \
   --model_path path-to-Qwen3-Next-80B-A3B-Thinking \
+  --gguf_path path-to-Qwen3-Next-80B-A3B-Thinking \
   --model_name Qwen3NextForCausalLM \
   --optimize_config_path <local_path>/ktransformers/optimize/optimize_rules/Qwen3Next-serve.yaml \
   --max_new_tokens 1024 \
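
With the missing --gguf_path flag added, the documented launch command is complete. As a sketch only, the same command could be driven from Python like this (the model path and <local_path> are the doc's own placeholders, not real locations):

import subprocess

# Sketch of the documented launch command; MODEL_DIR stands in for the
# doc's "path-to-Qwen3-Next-80B-A3B-Thinking" placeholder.
MODEL_DIR = "path-to-Qwen3-Next-80B-A3B-Thinking"

subprocess.run(
    [
        "python", "ktransformers/server/main.py",
        "--port", "10021",
        "--model_path", MODEL_DIR,
        "--gguf_path", MODEL_DIR,  # the flag this commit adds to the docs
        "--model_name", "Qwen3NextForCausalLM",
        "--optimize_config_path",
        "<local_path>/ktransformers/optimize/optimize_rules/Qwen3Next-serve.yaml",
        "--max_new_tokens", "1024",
    ],
    check=True,
)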

ktransformers/models/modeling_qwen3_next.py

Lines changed: 17 additions & 4 deletions

@@ -43,10 +43,23 @@
 from transformers.utils import TransformersKwargs, auto_docstring, can_return_tuple, logging
 from transformers.utils.deprecation import deprecate_kwarg
 from transformers.utils.generic import OutputRecorder, check_model_inputs
-from transformers.utils.import_utils import (
-    is_causal_conv1d_available,
-    is_flash_linear_attention_available,
-)
+try:
+    from transformers.utils.import_utils import (
+        is_causal_conv1d_available,
+        is_flash_linear_attention_available,
+    )
+except ImportError:
+    is_causal_conv1d_available = lambda: False
+
+
+try:
+    from transformers.utils.import_utils import (
+        is_flash_linear_attention_available,
+    )
+except ImportError:
+    is_flash_linear_attention_available = lambda: False
+
+
 from .configuration_qwen3_next import Qwen3NextConfig
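
Why the guards: with the original single import, a transformers build missing either symbol raises ImportError and both names end up undefined. The commit wraps the combined import in try/except (falling back for is_causal_conv1d_available) and then re-imports is_flash_linear_attention_available under its own guard, so each name is always bound. A minimal standalone sketch of the same pattern, one symbol per guard (equivalent in effect, not copied verbatim from the repo):

# Guard each optional symbol separately so a failed import of one
# still leaves the other bound to a "not available" fallback.
try:
    from transformers.utils.import_utils import is_causal_conv1d_available
except ImportError:
    is_causal_conv1d_available = lambda: False

try:
    from transformers.utils.import_utils import is_flash_linear_attention_available
except ImportError:
    is_flash_linear_attention_available = lambda: False

# Callers can branch on the flags without risking a NameError.
if not is_flash_linear_attention_available():
    print("flash linear attention unavailable; using fallback path")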

ktransformers/operators/balance_serve_attention.py

Lines changed: 5 additions & 5 deletions

@@ -614,11 +614,6 @@ def forward(self,
         query_states = self.q_norm(query_states, bsz_tensors)
         key_states = self.k_norm(key_states, bsz_tensors)
 
-
-        query_states = query_states.view(q_len, self.config.num_attention_heads, self.head_dim)
-        key_states = key_states.view(q_len, self.config.num_key_value_heads, self.head_dim)
-        value_states = value_states.view(q_len, self.config.num_key_value_heads, self.head_dim)
-
         # cos, sin = freqs_cis
         """
         print(query_states.shape)
@@ -634,11 +629,16 @@
         if freqs_cis is not None:
             query_states, key_states = self.apply_rotary_pos_emb(query_states.unsqueeze(0), key_states.unsqueeze(0), freqs_cis)
 
+        query_states = query_states.view(q_len, self.config.num_attention_heads, self.head_dim)
+        key_states = key_states.view(q_len, self.config.num_key_value_heads, self.head_dim)
+        value_states = value_states.view(q_len, self.config.num_key_value_heads, self.head_dim)
+
 
         k_cache = kv_cache.get_k_cache(self.layer_idx)
         v_cache = kv_cache.get_v_cache(self.layer_idx)
 
 
+        print(f"{k_cache.shape=}, {v_cache.shape=}, {query_states.shape=}, {key_states.shape=}, {value_states.shape=}")
         attn_output = wrapper.forward(query_states, k_cache, v_cache, key_states, value_states)
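
The substantive fix here moves the per-head reshapes so they run after apply_rotary_pos_emb rather than before it, presumably so the rotary call receives the flat (q_len, hidden_size) tensors it is handed via unsqueeze(0). A rough shape walk-through under made-up dimensions (head counts and head_dim below are illustrative, not read from the model config, and torch.nn.Identity stands in for the rotary function):

import torch

# Illustrative dimensions only.
q_len, n_heads, n_kv_heads, head_dim = 8, 16, 4, 64

# Flat projections, as they leave q_norm / k_norm.
query_states = torch.randn(q_len, n_heads * head_dim)
key_states = torch.randn(q_len, n_kv_heads * head_dim)
value_states = torch.randn(q_len, n_kv_heads * head_dim)

# Rotary embedding applied on the flat tensors with a batch dim
# prepended, mirroring query_states.unsqueeze(0) in the commit.
rotary = torch.nn.Identity()
query_states = rotary(query_states.unsqueeze(0)).squeeze(0)
key_states = rotary(key_states.unsqueeze(0)).squeeze(0)

# Only now are the tensors split per head for the attention wrapper.
query_states = query_states.view(q_len, n_heads, head_dim)
key_states = key_states.view(q_len, n_kv_heads, head_dim)
value_states = value_states.view(q_len, n_kv_heads, head_dim)

print(query_states.shape)  # torch.Size([8, 16, 64])
print(key_states.shape)    # torch.Size([8, 4, 64])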

ktransformers/server/args.py

Lines changed: 8 additions & 3 deletions

@@ -3,6 +3,7 @@
 from ktransformers.util.utils import get_free_ports
 from transformers import AutoConfig
 from ktransformers.models.configuration_qwen3_moe import Qwen3MoeConfig
+from ktransformers.models.configuration_qwen3_next import Qwen3NextConfig
 from ktransformers.models.configuration_smallthinker import SmallthinkerConfig
 from ktransformers.models.configuration_glm4_moe import Glm4MoeConfig

@@ -138,12 +139,16 @@ def parse_args(self):
         self.cfg.server_ip = args.host
         self.cfg.server_port = args.port
         self.cfg.user_force_think = args.force_think
+
+
+        args.architectures = args.model_name
+
         try:
             model_config = AutoConfig.from_pretrained(args.model_dir, trust_remote_code=True)
         except:
-            try:
-                model_config = Glm4MoeConfig.from_pretrained(args.model_dir, trust_remote_code=True)
-            except:
+            if args.model_name == "Qwen3NextForCausalLM":
+                model_config = Qwen3NextConfig.from_pretrained(args.model_dir)
+            else:
                 raise ValueError(f"Model {args.model_name} not supported. Please check your model directory or model name.")
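
The reworked fallback drops the old nested Glm4MoeConfig attempt: AutoConfig is tried first and succeeds whenever the installed transformers already knows the architecture; only on failure does the loader dispatch on the user-supplied model name. A condensed sketch of that control flow, wrapped in a hypothetical load_model_config helper and with the bare except narrowed to Exception:

from transformers import AutoConfig
from ktransformers.models.configuration_qwen3_next import Qwen3NextConfig

def load_model_config(model_dir: str, model_name: str):
    # Prefer transformers' own registry; it works when the installed
    # version already ships a qwen3_next config class.
    try:
        return AutoConfig.from_pretrained(model_dir, trust_remote_code=True)
    except Exception:
        # Older transformers: fall back to the config class bundled with
        # ktransformers, but only for the architecture this commit targets.
        if model_name == "Qwen3NextForCausalLM":
            return Qwen3NextConfig.from_pretrained(model_dir)
        raise ValueError(
            f"Model {model_name} not supported. "
            "Please check your model directory or model name."
        )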
