Commit ff26271

Qwen3 Next - cleaned up version

1 parent 9b17d74 commit ff26271

14 files changed: +1030 −13 lines

convert_hf_to_gguf.py

Lines changed: 29 additions & 0 deletions
```diff
@@ -4196,6 +4196,35 @@ def set_vocab(self):
 
         super().set_vocab()
 
+@ModelBase.register("Qwen3NextForCausalLM")
+class Qwen3NextModel(Qwen3MoeModel):
+    model_arch = gguf.MODEL_ARCH.QWEN3NEXT
+
+    def set_gguf_parameters(self):
+        super().set_gguf_parameters()
+        self.gguf_writer.add_ssm_conv_kernel(self.find_hparam(["linear_conv_kernel_dim"]))
+        self.gguf_writer.add_ssm_state_size(self.find_hparam(["linear_key_head_dim"]))
+        self.gguf_writer.add_ssm_group_count(self.find_hparam(["linear_num_key_heads"]))
+        self.gguf_writer.add_ssm_time_step_rank(self.find_hparam(["linear_num_value_heads"]))
+        self.gguf_writer.add_ssm_inner_size(self.find_hparam(["linear_value_head_dim"]) * self.find_hparam(["linear_num_value_heads"]))
+        if (rope_dim := self.hparams.get("head_dim")) is None:
+            rope_dim = self.hparams["hidden_size"] // self.hparams["num_attention_heads"]
+        self.gguf_writer.add_rope_dimension_count(int(rope_dim * self.hparams.get("partial_rotary_factor", 0.25)))
+
+    def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
+        if name.startswith("mtp"):
+            return []  # ignore MTP layers for now
+        if name.endswith(".A_log"):
+            data_torch = -torch.exp(data_torch)
+        elif name.endswith(".dt_bias"):
+            name = name.rpartition(".dt_bias")[0] + ".dt_proj.bias"
+        elif "conv1d" in name:
+            data_torch = data_torch.squeeze()
+        elif name.endswith("norm.weight") and not name.endswith("linear_attn.norm.weight"):
+            data_torch = data_torch + 1
+
+        yield from Qwen2MoeModel.modify_tensors(self, data_torch, name, bid)
+
 
 @ModelBase.register("Qwen3VLForConditionalGeneration", "Qwen3VLMoeForConditionalGeneration")
 class Qwen3VLVisionModel(MmprojModel):
```
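The hparam mapping reuses the generic SSM metadata keys (with `add_ssm_time_step_rank` overloaded to carry the value-head count), and `modify_tensors` follows the existing Mamba-2 conventions: `A_log` is baked into `-exp(A_log)` at conversion time, the depthwise `conv1d` weight is squeezed to 2D, and zero-centered RMS-norm weights are stored as `weight + 1`. A worked sketch of the derived values (the hparams below are assumptions modeled on the published Qwen3-Next-80B-A3B config.json, not read from a checkpoint):

```python
# Minimal sketch of the arithmetic in set_gguf_parameters() above.
# Hparam values are assumed from the published Qwen3-Next-80B-A3B
# config.json; treat them as illustrative.
hparams = {
    "linear_conv_kernel_dim": 4,   # -> ssm_conv_kernel
    "linear_key_head_dim": 128,    # -> ssm_state_size
    "linear_num_key_heads": 16,    # -> ssm_group_count
    "linear_num_value_heads": 32,  # -> ssm_time_step_rank (overloaded)
    "linear_value_head_dim": 128,
    "head_dim": 256,
    "partial_rotary_factor": 0.25,
}

# ssm_inner_size = value_head_dim * num_value_heads
ssm_inner_size = hparams["linear_value_head_dim"] * hparams["linear_num_value_heads"]
# only a quarter of each attention head gets rotary embeddings
rope_dims = int(hparams["head_dim"] * hparams["partial_rotary_factor"])
print(ssm_inner_size, rope_dims)  # 4096 64
```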

examples/model-conversion/scripts/causal/run-converted-model.sh

Lines changed: 7 additions & 1 deletion
```diff
@@ -4,6 +4,11 @@ set -e
 
 # First try command line argument, then environment variable, then file
 CONVERTED_MODEL="${1:-"$CONVERTED_MODEL"}"
+MODEL_TESTING_PROMPT="${2:-"$MODEL_TESTING_PROMPT"}"
+
+if [ -z "$MODEL_TESTING_PROMPT" ]; then
+    MODEL_TESTING_PROMPT="Hello, my name is"
+fi
 
 # Final check if we have a model path
 if [ -z "$CONVERTED_MODEL" ]; then
@@ -14,7 +19,8 @@ if [ -z "$CONVERTED_MODEL" ]; then
 fi
 
 echo $CONVERTED_MODEL
+echo $MODEL_TESTING_PROMPT
 
 cmake --build ../../build --target llama-logits -j8
 
-../../build/bin/llama-logits -m "$CONVERTED_MODEL" "Hello, my name is"
+../../build/bin/llama-logits -m "$CONVERTED_MODEL" "$MODEL_TESTING_PROMPT"
```
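With the optional second argument the script can now be driven as, for example, `./run-converted-model.sh models/model.gguf "The capital of France is"` (the path here is illustrative), or by exporting `MODEL_TESTING_PROMPT`; with neither set it falls back to the previous hard-coded `"Hello, my name is"`.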

examples/model-conversion/scripts/causal/run-org-model.py

Lines changed: 6 additions & 2 deletions
```diff
@@ -184,8 +184,12 @@ def fn(_m, input, output):
 # of using AutoModelForCausalLM.
 print(f"Model class: {model.__class__.__name__}")
 
-prompt = "Hello, my name is"
-input_ids = tokenizer(prompt, return_tensors="pt").input_ids
+device = next(model.parameters()).device
+if os.getenv("MODEL_TESTING_PROMPT"):
+    prompt = os.getenv("MODEL_TESTING_PROMPT")
+else:
+    prompt = "Hello, my name is"
+input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(device)
 
 print(f"Input tokens: {input_ids}")
 print(f"Input text: {repr(prompt)}")
```
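This mirrors the shell script: exporting the same `MODEL_TESTING_PROMPT` to both runners keeps the original-vs-converted logit comparison on identical inputs (the snippet assumes `os` is already imported earlier in the script). Moving `input_ids` to the model's device also fixes runs where the original model is loaded on GPU.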

gguf-py/gguf/constants.py

Lines changed: 33 additions & 0 deletions
```diff
@@ -352,6 +352,7 @@ class MODEL_ARCH(IntEnum):
     QWEN2VL      = auto()
     QWEN3        = auto()
     QWEN3MOE     = auto()
+    QWEN3NEXT    = auto()
     QWEN3VL      = auto()
     QWEN3VLMOE   = auto()
     PHI2         = auto()
@@ -516,6 +517,7 @@ class MODEL_TENSOR(IntEnum):
     SSM_D          = auto()
     SSM_NORM       = auto()
     SSM_OUT        = auto()
+    SSM_BETA_ALPHA = auto()  # qwen3next
     TIME_MIX_W0    = auto()
     TIME_MIX_W1    = auto()
     TIME_MIX_W2    = auto()
@@ -721,6 +723,7 @@ class MODEL_TENSOR(IntEnum):
     MODEL_ARCH.QWEN2VL:    "qwen2vl",
     MODEL_ARCH.QWEN3:      "qwen3",
     MODEL_ARCH.QWEN3MOE:   "qwen3moe",
+    MODEL_ARCH.QWEN3NEXT:  "qwen3next",
     MODEL_ARCH.QWEN3VL:    "qwen3vl",
     MODEL_ARCH.QWEN3VLMOE: "qwen3vlmoe",
     MODEL_ARCH.PHI2:       "phi2",
@@ -884,6 +887,7 @@ class MODEL_TENSOR(IntEnum):
     MODEL_TENSOR.SSM_D:          "blk.{bid}.ssm_d",
     MODEL_TENSOR.SSM_NORM:       "blk.{bid}.ssm_norm",
     MODEL_TENSOR.SSM_OUT:        "blk.{bid}.ssm_out",
+    MODEL_TENSOR.SSM_BETA_ALPHA: "blk.{bid}.ssm_ba",
     MODEL_TENSOR.TIME_MIX_W0:    "blk.{bid}.time_mix_w0",
     MODEL_TENSOR.TIME_MIX_W1:    "blk.{bid}.time_mix_w1",
     MODEL_TENSOR.TIME_MIX_W2:    "blk.{bid}.time_mix_w2",
@@ -1553,6 +1557,35 @@ class MODEL_TENSOR(IntEnum):
         MODEL_TENSOR.FFN_DOWN_EXP,
         MODEL_TENSOR.FFN_UP_EXP,
     ],
+    MODEL_ARCH.QWEN3NEXT: [
+        MODEL_TENSOR.TOKEN_EMBD,
+        MODEL_TENSOR.OUTPUT_NORM,
+        MODEL_TENSOR.OUTPUT,
+        MODEL_TENSOR.ATTN_NORM,
+        MODEL_TENSOR.ATTN_Q,
+        MODEL_TENSOR.ATTN_Q_NORM,
+        MODEL_TENSOR.ATTN_K,
+        MODEL_TENSOR.ATTN_K_NORM,
+        MODEL_TENSOR.ATTN_V,
+        MODEL_TENSOR.ATTN_OUT,
+        MODEL_TENSOR.ATTN_POST_NORM,
+        MODEL_TENSOR.ATTN_GATE,
+        MODEL_TENSOR.FFN_GATE_INP,
+        MODEL_TENSOR.FFN_GATE_INP_SHEXP,
+        MODEL_TENSOR.FFN_UP_SHEXP,
+        MODEL_TENSOR.FFN_DOWN_SHEXP,
+        MODEL_TENSOR.FFN_GATE_SHEXP,
+        MODEL_TENSOR.FFN_DOWN_EXP,
+        MODEL_TENSOR.FFN_UP_EXP,
+        MODEL_TENSOR.FFN_GATE_EXP,
+        MODEL_TENSOR.SSM_A,
+        MODEL_TENSOR.SSM_CONV1D,
+        MODEL_TENSOR.SSM_DT,
+        MODEL_TENSOR.SSM_NORM,
+        MODEL_TENSOR.SSM_IN,
+        MODEL_TENSOR.SSM_BETA_ALPHA,
+        MODEL_TENSOR.SSM_OUT
+    ],
     MODEL_ARCH.QWEN3VL: [
         MODEL_TENSOR.TOKEN_EMBD,
         MODEL_TENSOR.OUTPUT_NORM,
```
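`SSM_BETA_ALPHA` is the only genuinely new tensor slot: it carries the fused beta/alpha projection (`in_proj_ba`) of the gated delta-net layers and serializes as `blk.{bid}.ssm_ba`. A quick sanity check of the new registrations, assuming the usual `gguf-py` module layout:

```python
from gguf.constants import MODEL_ARCH, MODEL_ARCH_NAMES, MODEL_TENSOR, TENSOR_NAMES

# Both lookups exercise only the tables extended in this commit.
print(MODEL_ARCH_NAMES[MODEL_ARCH.QWEN3NEXT])                   # qwen3next
print(TENSOR_NAMES[MODEL_TENSOR.SSM_BETA_ALPHA].format(bid=3))  # blk.3.ssm_ba
```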

gguf-py/gguf/tensor_mapping.py

Lines changed: 16 additions & 6 deletions
```diff
@@ -672,17 +672,19 @@ class TensorNameMap:
         ),
 
         MODEL_TENSOR.SSM_IN: (
-            "model.layers.{bid}.in_proj",                  # mamba-hf
-            "backbone.layers.{bid}.mixer.in_proj",         # mamba
-            "model.layers.{bid}.mamba.in_proj",            # jamba falcon-h1 granite-hybrid
-            "model.layers.layers.{bid}.mixer.in_proj",     # plamo2
+            "model.layers.{bid}.in_proj",                   # mamba-hf
+            "backbone.layers.{bid}.mixer.in_proj",          # mamba
+            "model.layers.{bid}.mamba.in_proj",             # jamba falcon-h1 granite-hybrid
+            "model.layers.layers.{bid}.mixer.in_proj",      # plamo2
+            "model.layers.{bid}.linear_attn.in_proj_qkvz",  # qwen3next
         ),
 
         MODEL_TENSOR.SSM_CONV1D: (
             "model.layers.{bid}.conv1d",                    # mamba-hf
             "backbone.layers.{bid}.mixer.conv1d",           # mamba
             "model.layers.{bid}.mamba.conv1d",              # jamba falcon-h1 granite-hybrid
             "model.layers.layers.{bid}.mixer.conv1d",       # plamo2
+            "model.layers.{bid}.linear_attn.conv1d",        # qwen3next
         ),
 
         MODEL_TENSOR.SSM_X: (
@@ -697,6 +699,7 @@ class TensorNameMap:
             "backbone.layers.{bid}.mixer.dt_proj",          # mamba
             "model.layers.{bid}.mamba.dt_proj",             # jamba falcon-h1 granite-hybrid
             "model.layers.layers.{bid}.mixer.dt_proj",      # plamo2
+            "model.layers.{bid}.linear_attn.dt_proj",       # qwen3next
         ),
 
         MODEL_TENSOR.SSM_DT_NORM: (
@@ -709,6 +712,7 @@ class TensorNameMap:
             "backbone.layers.{bid}.mixer.A_log",            # mamba
             "model.layers.{bid}.mamba.A_log",               # jamba falcon-h1 granite-hybrid
             "model.layers.layers.{bid}.mixer.A_log",        # plamo2
+            "model.layers.{bid}.linear_attn.A_log",         # qwen3next
         ),
 
         MODEL_TENSOR.SSM_B_NORM: (
@@ -731,17 +735,23 @@ class TensorNameMap:
         ),
 
         MODEL_TENSOR.SSM_NORM: (
-            "model.layers.{bid}.mamba.norm",     # falcon-h1 granite-hybrid
-            "backbone.layers.{bid}.mixer.norm",  # mamba2
+            "model.layers.{bid}.mamba.norm",                # falcon-h1 granite-hybrid
+            "model.layers.{bid}.linear_attn.norm",          # qwen3next
+            "backbone.layers.{bid}.mixer.norm",             # mamba2
         ),
 
         MODEL_TENSOR.SSM_OUT: (
             "model.layers.{bid}.out_proj",                  # mamba-hf
             "backbone.layers.{bid}.mixer.out_proj",         # mamba
             "model.layers.{bid}.mamba.out_proj",            # jamba falcon-h1 granite-hybrid
+            "model.layers.{bid}.linear_attn.out_proj",      # qwen3next
             "model.layers.layers.{bid}.mixer.out_proj",     # plamo2
         ),
 
+        MODEL_TENSOR.SSM_BETA_ALPHA: (
+            "model.layers.{bid}.linear_attn.in_proj_ba",    # qwen3next
+        ),
+
         MODEL_TENSOR.TIME_MIX_W0: (
             "model.layers.{bid}.attention.w0",              # rwkv7
         ),
```
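With these entries the generic converter can route Qwen3-Next's `linear_attn.*` checkpoint names onto the existing SSM slots without model-specific code. A minimal sketch of the resolution path (the block count of 48 is an assumption for the 80B model):

```python
from gguf.constants import MODEL_ARCH
from gguf.tensor_mapping import get_tensor_name_map

# Build the HF-name -> GGUF-name map for the new architecture.
name_map = get_tensor_name_map(MODEL_ARCH.QWEN3NEXT, 48)
print(name_map.get_name("model.layers.0.linear_attn.in_proj_ba.weight",
                        try_suffixes=(".weight", ".bias")))
# expected: blk.0.ssm_ba.weight
```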

src/llama-arch.cpp

Lines changed: 33 additions & 0 deletions
```diff
@@ -32,6 +32,7 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
     { LLM_ARCH_QWEN2VL,          "qwen2vl"    },
     { LLM_ARCH_QWEN3,            "qwen3"      },
     { LLM_ARCH_QWEN3MOE,         "qwen3moe"   },
+    { LLM_ARCH_QWEN3NEXT,        "qwen3next"  },
     { LLM_ARCH_QWEN3VL,          "qwen3vl"    },
     { LLM_ARCH_QWEN3VLMOE,       "qwen3vlmoe" },
     { LLM_ARCH_PHI2,             "phi2"       },
@@ -816,6 +817,38 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_N
             { LLM_TENSOR_FFN_UP_EXPS,        "blk.%d.ffn_up_exps" },
         },
     },
+    {
+        LLM_ARCH_QWEN3NEXT,
+        {
+            { LLM_TENSOR_TOKEN_EMBD,          "token_embd" },
+            { LLM_TENSOR_OUTPUT_NORM,         "output_norm" },
+            { LLM_TENSOR_OUTPUT,              "output" },
+            { LLM_TENSOR_ATTN_NORM,           "blk.%d.attn_norm" },
+            { LLM_TENSOR_ATTN_POST_NORM,      "blk.%d.post_attention_norm" },
+            { LLM_TENSOR_ATTN_Q,              "blk.%d.attn_q" },
+            { LLM_TENSOR_ATTN_Q_NORM,         "blk.%d.attn_q_norm" },
+            { LLM_TENSOR_ATTN_K,              "blk.%d.attn_k" },
+            { LLM_TENSOR_ATTN_K_NORM,         "blk.%d.attn_k_norm" },
+            { LLM_TENSOR_ATTN_V,              "blk.%d.attn_v" },
+            { LLM_TENSOR_ATTN_OUT,            "blk.%d.attn_output" },
+            { LLM_TENSOR_FFN_NORM,            "blk.%d.ffn_norm" },
+            { LLM_TENSOR_FFN_GATE_INP,        "blk.%d.ffn_gate_inp" },
+            { LLM_TENSOR_FFN_GATE_EXPS,       "blk.%d.ffn_gate_exps" },
+            { LLM_TENSOR_FFN_DOWN_EXPS,       "blk.%d.ffn_down_exps" },
+            { LLM_TENSOR_FFN_UP_EXPS,         "blk.%d.ffn_up_exps" },
+            { LLM_TENSOR_FFN_GATE_INP_SHEXP,  "blk.%d.ffn_gate_inp_shexp" },
+            { LLM_TENSOR_FFN_GATE_SHEXP,      "blk.%d.ffn_gate_shexp" },
+            { LLM_TENSOR_FFN_DOWN_SHEXP,      "blk.%d.ffn_down_shexp" },
+            { LLM_TENSOR_FFN_UP_SHEXP,        "blk.%d.ffn_up_shexp" },
+            { LLM_TENSOR_SSM_A,               "blk.%d.ssm_a" },
+            { LLM_TENSOR_SSM_CONV1D,          "blk.%d.ssm_conv1d" },
+            { LLM_TENSOR_SSM_DT,              "blk.%d.ssm_dt" },
+            { LLM_TENSOR_SSM_BETA_ALPHA,      "blk.%d.ssm_ba" },
+            { LLM_TENSOR_SSM_IN,              "blk.%d.ssm_in" },
+            { LLM_TENSOR_SSM_NORM,            "blk.%d.ssm_norm" },
+            { LLM_TENSOR_SSM_OUT,             "blk.%d.ssm_out" },
+        },
+    },
     {
         LLM_ARCH_QWEN3VL,
         {
```

src/llama-arch.h

Lines changed: 2 additions & 0 deletions
```diff
@@ -36,6 +36,7 @@ enum llm_arch {
     LLM_ARCH_QWEN2VL,
     LLM_ARCH_QWEN3,
     LLM_ARCH_QWEN3MOE,
+    LLM_ARCH_QWEN3NEXT,
     LLM_ARCH_QWEN3VL,
     LLM_ARCH_QWEN3VLMOE,
     LLM_ARCH_PHI2,
@@ -368,6 +369,7 @@ enum llm_tensor {
     LLM_TENSOR_SSM_D,
     LLM_TENSOR_SSM_NORM,
     LLM_TENSOR_SSM_OUT,
+    LLM_TENSOR_SSM_BETA_ALPHA, // qwen3next
     LLM_TENSOR_TIME_MIX_W0,
     LLM_TENSOR_TIME_MIX_W1,
     LLM_TENSOR_TIME_MIX_W2,
```

src/llama-context.cpp

Lines changed: 4 additions & 0 deletions
```diff
@@ -1,5 +1,6 @@
 #include "llama-context.h"
 
+#include "llama-arch.h"
 #include "llama-impl.h"
 #include "llama-batch.h"
 #include "llama-io.h"
@@ -1386,6 +1387,9 @@ void llama_context::output_reorder() {
 //
 
 uint32_t llama_context::graph_max_nodes() const {
+    if (model.arch == LLM_ARCH_QWEN3NEXT) {
+        return std::max<uint32_t>(8192, 32u*model.n_tensors());
+    }
     return std::max<uint32_t>(1024u, 8u*model.n_tensors());
 }
```
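The higher node budget (32 per tensor with a floor of 8192, versus the default 8 per tensor with a floor of 1024) is presumably needed because each linear-attention layer unrolls its delta-rule recurrence into many small graph ops; the multiplier reads as an empirical safety margin rather than a derived bound.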

src/llama-hparams.h

Lines changed: 1 addition & 1 deletion
```diff
@@ -6,7 +6,7 @@
 
 // bump if necessary
 #define LLAMA_MAX_LAYERS  512
-#define LLAMA_MAX_EXPERTS 384 // Kimi-K2
+#define LLAMA_MAX_EXPERTS 512 // Qwen3 Next
 
 enum llama_expert_gating_func_type {
     LLAMA_EXPERT_GATING_FUNC_TYPE_NONE = 0,
```
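`LLAMA_MAX_EXPERTS` is a compile-time validation ceiling rather than an allocation; the bump matches Qwen3-Next's reported 512 routed experts, up from the 384 that Kimi-K2 required.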

0 commit comments