|
|
@@ -112,6 +112,7 @@ const char * llm_type_name(llm_type type) {
|
|
|
case LLM_TYPE_A13B: return "A13B";
|
|
|
case LLM_TYPE_21B_A3B: return "21B.A3B";
|
|
|
case LLM_TYPE_30B_A3B: return "30B.A3B";
|
|
|
+ case LLM_TYPE_80B_A3B: return "80B.A3B";
|
|
|
case LLM_TYPE_106B_A12B: return "106B.A12B";
|
|
|
case LLM_TYPE_235B_A22B: return "235B.A22B";
|
|
|
case LLM_TYPE_300B_A47B: return "300B.A47B";
|
|
|
@@ -1809,6 +1810,29 @@ void llama_model::load_hparams(llama_model_loader & ml) {
|
|
|
// For Granite MoE Shared
|
|
|
ml.get_key(LLM_KV_EXPERT_SHARED_FEED_FORWARD_LENGTH, hparams.n_ff_shexp, /* required */ false);
|
|
|
} break;
|
|
|
+ case LLM_ARCH_QWEN3NEXT:
|
|
|
+ {
|
|
|
+ ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp, false);
|
|
|
+ ml.get_key(LLM_KV_EXPERT_SHARED_FEED_FORWARD_LENGTH, hparams.n_ff_shexp, false);
|
|
|
+ ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
|
|
|
+
|
|
|
+ // Load linear attention (gated delta net) parameters
|
|
|
+ ml.get_key(LLM_KV_SSM_CONV_KERNEL, hparams.ssm_d_conv);
|
|
|
+ ml.get_key(LLM_KV_SSM_INNER_SIZE, hparams.ssm_d_inner);
|
|
|
+ ml.get_key(LLM_KV_SSM_STATE_SIZE, hparams.ssm_d_state);
|
|
|
+ ml.get_key(LLM_KV_SSM_TIME_STEP_RANK, hparams.ssm_dt_rank);
|
|
|
+ ml.get_key(LLM_KV_SSM_GROUP_COUNT, hparams.ssm_n_group);
|
|
|
+
|
|
|
+ // Mark recurrent layers (linear attention layers)
|
|
|
+ for (uint32_t i = 0; i < hparams.n_layer; ++i) {
|
|
|
+ hparams.recurrent_layer_arr[i] = ((i + 1) % 4 != 0); // TODO: extract the magic 4 from "full_attention_interval"
|
|
|
+ }
|
|
|
+
|
|
|
+ switch (hparams.n_layer) {
|
|
|
+ case 80: type = LLM_TYPE_80B_A3B; break;
|
|
|
+ default: type = LLM_TYPE_UNKNOWN;
|
|
|
+ }
|
|
|
+ } break;
|
|
|
case LLM_ARCH_CHAMELEON:
|
|
|
{
|
|
|
ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
|
|
|
@@ -2360,6 +2384,76 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
|
|
|
}
|
|
|
}
|
|
|
} break;
|
|
|
+ case LLM_ARCH_QWEN3NEXT:
|
|
|
+ {
|
|
|
+ tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), { n_embd, n_vocab }, 0);
|
|
|
+
|
|
|
+ // output
|
|
|
+ output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), { n_embd }, 0);
|
|
|
+ output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), { n_embd, n_vocab }, TENSOR_NOT_REQUIRED);
|
|
|
+
|
|
|
+ // if output is NULL, init from the input tok embed
|
|
|
+ if (output == NULL) {
|
|
|
+ output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), { n_embd, n_vocab }, TENSOR_DUPLICATED);
|
|
|
+ }
|
|
|
+
|
|
|
+ const int64_t n_ff_exp = hparams.n_ff_exp ? hparams.n_ff_exp : n_ff / n_expert_used;
|
|
|
+
|
|
|
+ // Calculate dimensions from hyperparameters
|
|
|
+ const int64_t head_k_dim = hparams.ssm_d_state;
|
|
|
+ const int64_t head_v_dim = hparams.ssm_d_state;
|
|
|
+ const int64_t n_k_heads = hparams.ssm_n_group;
|
|
|
+ const int64_t n_v_heads = hparams.ssm_dt_rank;
|
|
|
+ const int64_t key_dim = head_k_dim * n_k_heads;
|
|
|
+ const int64_t value_dim = head_v_dim * n_v_heads;
|
|
|
+ const int64_t conv_dim = key_dim * 2 + value_dim;
|
|
|
+
|
|
|
+ // Calculate projection sizes
|
|
|
+ const int64_t qkvz_projection_size = key_dim * 2 + value_dim * 2;
|
|
|
+ const int64_t ba_projection_size = n_v_heads * 2;
|
|
|
+
|
|
|
+ for (int i = 0; i < n_layer; ++i) {
|
|
|
+ auto & layer = layers[i];
|
|
|
+
|
|
|
+ layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), { n_embd }, 0);
|
|
|
+ layer.attn_post_norm = create_tensor(tn(LLM_TENSOR_ATTN_POST_NORM, "weight", i), { n_embd }, 0);
|
|
|
+
|
|
|
+ if ((i + 1) % 4 == 0) { // TODO: magic 4
|
|
|
+ // Attention layers
|
|
|
+ layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), { n_embd, n_ff }, 0);
|
|
|
+ layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), { n_embd, n_embd_k_gqa }, 0);
|
|
|
+ layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), { n_embd, n_embd_v_gqa }, 0);
|
|
|
+ layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), { n_embd_head_k * n_head, n_embd }, 0);
|
|
|
+
|
|
|
+ // Q/K normalization for attention layers
|
|
|
+ layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), { n_embd_head_k }, 0);
|
|
|
+ layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), { n_embd_head_k }, 0);
|
|
|
+
|
|
|
+ } else {
|
|
|
+ // Linear attention (gated delta net) specific tensors
|
|
|
+ // Create tensors with calculated dimensions
|
|
|
+ layer.ssm_in = create_tensor(tn(LLM_TENSOR_SSM_IN, "weight", i), { n_embd, qkvz_projection_size }, 0);
|
|
|
+ layer.ssm_conv1d = create_tensor(tn(LLM_TENSOR_SSM_CONV1D, "weight", i), { hparams.ssm_d_conv, conv_dim }, 0);
|
|
|
+ layer.ssm_dt = create_tensor(tn(LLM_TENSOR_SSM_DT, "bias", i), { hparams.ssm_dt_rank }, 0);
|
|
|
+ layer.ssm_a = create_tensor(tn(LLM_TENSOR_SSM_A, i), { hparams.ssm_dt_rank }, 0);
|
|
|
+ layer.ssm_beta_alpha = create_tensor(tn(LLM_TENSOR_SSM_BETA_ALPHA, "weight", i), { n_embd, ba_projection_size }, 0);
|
|
|
+ layer.ssm_norm = create_tensor(tn(LLM_TENSOR_SSM_NORM, "weight", i), { head_v_dim }, 0);
|
|
|
+ layer.ssm_out = create_tensor(tn(LLM_TENSOR_SSM_OUT, "weight", i), { n_ff, n_embd }, 0);
|
|
|
+ }
|
|
|
+
|
|
|
+ layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), { n_embd, n_expert }, 0);
|
|
|
+ layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), { n_embd, n_ff_exp, n_expert }, 0);
|
|
|
+ layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), { n_ff_exp, n_embd, n_expert }, 0);
|
|
|
+ layer.ffn_up_exps = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), { n_embd, n_ff_exp, n_expert }, 0);
|
|
|
+
|
|
|
+ // Shared experts
|
|
|
+ layer.ffn_gate_inp_shexp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP_SHEXP, "weight", i), { n_embd }, 0);
|
|
|
+ layer.ffn_gate_shexp = create_tensor(tn(LLM_TENSOR_FFN_GATE_SHEXP, "weight", i), { n_embd, hparams.n_ff_shexp }, 0);
|
|
|
+ layer.ffn_up_shexp = create_tensor(tn(LLM_TENSOR_FFN_UP_SHEXP, "weight", i), { n_embd, hparams.n_ff_shexp }, 0);
|
|
|
+ layer.ffn_down_shexp = create_tensor(tn(LLM_TENSOR_FFN_DOWN_SHEXP, "weight", i), { hparams.n_ff_shexp, n_embd }, 0);
|
|
|
+ }
|
|
|
+ }
|
|
|
+ break;
|
|
|
case LLM_ARCH_LLADA:
|
|
|
{
|
|
|
tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), { n_embd, n_vocab }, 0);
|
|
|
@@ -6075,7 +6169,8 @@ void llama_model::print_info() const {
|
|
|
arch == LLM_ARCH_FALCON_H1 ||
|
|
|
arch == LLM_ARCH_PLAMO2 ||
|
|
|
arch == LLM_ARCH_GRANITE_HYBRID ||
|
|
|
- arch == LLM_ARCH_NEMOTRON_H) {
|
|
|
+ arch == LLM_ARCH_NEMOTRON_H ||
|
|
|
+ arch == LLM_ARCH_QWEN3NEXT) {
|
|
|
LLAMA_LOG_INFO("%s: ssm_d_conv = %u\n", __func__, hparams.ssm_d_conv);
|
|
|
LLAMA_LOG_INFO("%s: ssm_d_inner = %u\n", __func__, hparams.ssm_d_inner);
|
|
|
LLAMA_LOG_INFO("%s: ssm_d_state = %u\n", __func__, hparams.ssm_d_state);
|
|
|
@@ -18827,6 +18922,329 @@ struct llm_build_smallthinker : public llm_graph_context{
|
|
|
}
|
|
|
};
|
|
|
|
|
|
+struct llm_build_qwen3next : public llm_graph_context_mamba {
|
|
|
+ llm_build_qwen3next(const llama_model & model, const llm_graph_params & params) : llm_graph_context_mamba(params) {
|
|
|
+ const int64_t n_embd_head = hparams.n_embd_head_v;
|
|
|
+ GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
|
|
|
+
|
|
|
+ ggml_tensor * cur;
|
|
|
+ ggml_tensor * inpL;
|
|
|
+
|
|
|
+ inpL = build_inp_embd(model.tok_embd);
|
|
|
+
|
|
|
+ auto * inp = build_inp_mem_hybrid();
|
|
|
+
|
|
|
+ ggml_tensor * inp_pos = build_inp_pos();
|
|
|
+
|
|
|
+ ggml_tensor * inp_out_ids = build_inp_out_ids();
|
|
|
+
|
|
|
+ for (int il = 0; il < n_layer; ++il) {
|
|
|
+ struct ggml_tensor * inpSA = inpL;
|
|
|
+
|
|
|
+ // Pre-norm for attention/linear attention
|
|
|
+ cur = build_norm(inpL,
|
|
|
+ model.layers[il].attn_norm, NULL,
|
|
|
+ LLM_NORM_RMS, il);
|
|
|
+ cb(cur, "attn_norm", il);
|
|
|
+
|
|
|
+ // Determine layer type and build appropriate attention mechanism
|
|
|
+ if (hparams.is_recurrent(il)) {
|
|
|
+ // Linear attention layer (gated delta net)
|
|
|
+ cur = build_qwen3next_linear_attn_layer(inp->get_recr(), cur, model, ubatch, il);
|
|
|
+ } else {
|
|
|
+ // Full attention layer
|
|
|
+ cur = build_qwen3next_attention_layer(
|
|
|
+ cur, inp_pos, inp->get_attn(), model,
|
|
|
+ n_embd_head, il);
|
|
|
+ }
|
|
|
+
|
|
|
+ // Post-attention norm
|
|
|
+ cur = build_norm(cur,
|
|
|
+ model.layers[il].attn_post_norm, NULL,
|
|
|
+ LLM_NORM_RMS, il);
|
|
|
+ cb(cur, "attn_post_norm", il);
|
|
|
+
|
|
|
+ if (il == n_layer - 1 && inp_out_ids) {
|
|
|
+ cur = ggml_get_rows(ctx0, cur, inp_out_ids);
|
|
|
+ inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
|
|
|
+ }
|
|
|
+
|
|
|
+ // Residual connection
|
|
|
+ cur = ggml_add(ctx0, cur, inpSA);
|
|
|
+ cb(cur, "attn_residual", il);
|
|
|
+
|
|
|
+ // FFN layer (MoE or dense)
|
|
|
+ cur = build_layer_ffn(cur, model, il);
|
|
|
+
|
|
|
+ // Input for next layer
|
|
|
+ inpL = cur;
|
|
|
+ }
|
|
|
+
|
|
|
+ cur = inpL;
|
|
|
+
|
|
|
+ // Final norm
|
|
|
+ cur = build_norm(cur,
|
|
|
+ model.output_norm, NULL,
|
|
|
+ LLM_NORM_RMS, -1);
|
|
|
+
|
|
|
+ cb(cur, "result_norm", -1);
|
|
|
+ res->t_embd = cur;
|
|
|
+
|
|
|
+ // LM head
|
|
|
+ cur = build_lora_mm(model.output, cur);
|
|
|
+
|
|
|
+ cb(cur, "result_output", -1);
|
|
|
+ res->t_logits = cur;
|
|
|
+
|
|
|
+ ggml_build_forward_expand(gf, cur);
|
|
|
+ }
|
|
|
+
|
|
|
+private:
|
|
|
+ ggml_tensor * build_qwen3next_attention_layer(
|
|
|
+ ggml_tensor * cur,
|
|
|
+ ggml_tensor * inp_pos,
|
|
|
+ llm_graph_input_attn_kv * inp_attn,
|
|
|
+ const llama_model & model,
|
|
|
+ const int64_t n_embd_head,
|
|
|
+ const int il) {
|
|
|
+
|
|
|
+ // QKV projection with gating
|
|
|
+ ggml_tensor * qkv_g = build_lora_mm(model.layers[il].wq, cur);
|
|
|
+ cb(qkv_g, "qkv_g", il);
|
|
|
+
|
|
|
+ // Split into Q and gate
|
|
|
+ const int64_t n_embd_q = hparams.n_head(il) * n_embd_head;
|
|
|
+ ggml_tensor * Qcur = ggml_view_3d(ctx0, qkv_g, n_embd_head, hparams.n_head(il), n_tokens,
|
|
|
+ n_embd_head * sizeof(float), qkv_g->nb[1], 0);
|
|
|
+ ggml_tensor * gate = ggml_view_3d(ctx0, qkv_g, n_embd_head, hparams.n_head(il), n_tokens,
|
|
|
+ n_embd_head * sizeof(float), qkv_g->nb[1], n_embd_q * ggml_element_size(qkv_g));
|
|
|
+
|
|
|
+ // K and V projections
|
|
|
+ ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur);
|
|
|
+ ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur);
|
|
|
+ cb(Kcur, "Kcur", il);
|
|
|
+ cb(Vcur, "Vcur", il);
|
|
|
+
|
|
|
+ Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, hparams.n_head(il), n_tokens);
|
|
|
+ Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, hparams.n_head_kv(il), n_tokens);
|
|
|
+ Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, hparams.n_head_kv(il), n_tokens);
|
|
|
+
|
|
|
+ // Apply Q/K normalization
|
|
|
+ Qcur = build_norm(Qcur, model.layers[il].attn_q_norm, NULL, LLM_NORM_RMS, il);
|
|
|
+ Kcur = build_norm(Kcur, model.layers[il].attn_k_norm, NULL, LLM_NORM_RMS, il);
|
|
|
+
|
|
|
+ // Apply RoPE
|
|
|
+ Qcur = ggml_rope_ext(
|
|
|
+ ctx0, Qcur, inp_pos, nullptr,
|
|
|
+ n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
|
|
|
+ ext_factor, attn_factor, beta_fast, beta_slow);
|
|
|
+
|
|
|
+ Kcur = ggml_rope_ext(
|
|
|
+ ctx0, Kcur, inp_pos, nullptr,
|
|
|
+ n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
|
|
|
+ ext_factor, attn_factor, beta_fast, beta_slow);
|
|
|
+
|
|
|
+ cb(Qcur, "Qcur", il);
|
|
|
+ cb(Kcur, "Kcur", il);
|
|
|
+ cb(Vcur, "Vcur", il);
|
|
|
+
|
|
|
+ // Attention computation
|
|
|
+ const float kq_scale = hparams.f_attention_scale == 0.0f ? 1.0f/sqrtf(float(n_embd_head)) : hparams.f_attention_scale;
|
|
|
+ cur = build_attn(inp_attn,
|
|
|
+ model.layers[il].wo, nullptr,
|
|
|
+ Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, kq_scale, il);
|
|
|
+
|
|
|
+ // Apply gating
|
|
|
+ gate = ggml_reshape_2d(ctx0, gate, n_embd_q, n_tokens);
|
|
|
+ cur = ggml_mul(ctx0, cur, ggml_sigmoid(ctx0, gate));
|
|
|
+ cb(cur, "attn_gated", il);
|
|
|
+
|
|
|
+ return cur;
|
|
|
+ }
|
|
|
+
|
|
|
+ ggml_tensor * build_qwen3next_linear_attn_layer(llm_graph_input_rs * inp,
|
|
|
+ ggml_tensor * cur,
|
|
|
+ const llama_model & model,
|
|
|
+ const llama_ubatch & ubatch,
|
|
|
+ int il) {
|
|
|
+ // Gated Delta Net implementation using the new ggml_delta_net function
|
|
|
+ const auto * mctx_cur = inp->mctx;
|
|
|
+ const auto kv_head = mctx_cur->get_head();
|
|
|
+
|
|
|
+ const int64_t d_inner = hparams.ssm_d_inner;
|
|
|
+ const int64_t d_state = hparams.ssm_d_state;
|
|
|
+ const int64_t n_heads = hparams.ssm_dt_rank;
|
|
|
+ const int64_t head_dim = d_inner / n_heads;
|
|
|
+ const int64_t n_seqs = ubatch.n_seqs;
|
|
|
+
|
|
|
+ const int64_t n_seq_tokens = ubatch.n_seq_tokens;
|
|
|
+
|
|
|
+ GGML_ASSERT(n_seqs != 0);
|
|
|
+ GGML_ASSERT(ubatch.equal_seqs());
|
|
|
+ GGML_ASSERT(ubatch.n_tokens == n_seq_tokens * n_seqs);
|
|
|
+
|
|
|
+ // Input projection for QKV and beta/alpha
|
|
|
+ ggml_tensor * qkvz_ba = build_lora_mm(model.layers[il].ssm_in, cur);
|
|
|
+ cb(qkvz_ba, "linear_attn_in_proj", il);
|
|
|
+
|
|
|
+ // Split into QKV and beta/alpha components
|
|
|
+ const int64_t qkv_size = d_inner * 2 + d_state * 2;
|
|
|
+
|
|
|
+ ggml_tensor * qkv =
|
|
|
+ ggml_view_3d(ctx0, qkvz_ba, qkv_size, n_tokens, 1, qkv_size * sizeof(float), qkvz_ba->nb[1], 0);
|
|
|
+ ggml_tensor * ba = ggml_view_2d(ctx0, qkvz_ba, n_embd, n_tokens,
|
|
|
+ qkvz_ba->nb[1], qkv_size * sizeof(float));
|
|
|
+
|
|
|
+ // Reshape QKV for processing
|
|
|
+ qkv = ggml_reshape_3d(ctx0, qkv, head_dim, n_heads * 2 + d_state * 2 / head_dim, n_tokens);
|
|
|
+
|
|
|
+ // Split into individual components
|
|
|
+ ggml_tensor * query =
|
|
|
+ ggml_view_3d(ctx0, qkv, head_dim, n_heads, n_tokens, head_dim * sizeof(float), qkv->nb[1], 0);
|
|
|
+ ggml_tensor * key = ggml_view_3d(ctx0, qkv, head_dim, n_heads, n_tokens, head_dim * sizeof(float), qkv->nb[1],
|
|
|
+ n_heads * head_dim * sizeof(float));
|
|
|
+ ggml_tensor * value = ggml_view_3d(ctx0, qkv, head_dim, n_heads, n_tokens, head_dim * sizeof(float), qkv->nb[1],
|
|
|
+ n_heads * head_dim * 2 * sizeof(float));
|
|
|
+
|
|
|
+ // Process beta and alpha parameters (corrected dimensions)
|
|
|
+ ggml_tensor * beta_alpha = build_lora_mm(model.layers[il].ssm_beta_alpha, ba);
|
|
|
+ ggml_tensor * beta =
|
|
|
+ ggml_view_3d(ctx0, beta_alpha, n_heads, n_tokens, n_seqs, n_heads * sizeof(float), beta_alpha->nb[1], 0);
|
|
|
+ ggml_tensor * alpha = ggml_view_3d(ctx0, beta_alpha, n_heads, n_tokens, n_seqs, n_heads * sizeof(float),
|
|
|
+ beta_alpha->nb[1], n_heads * sizeof(float));
|
|
|
+
|
|
|
+ // Apply sigmoid to beta (exactly like reference: beta = b.sigmoid())
|
|
|
+ beta = ggml_sigmoid(ctx0, beta);
|
|
|
+
|
|
|
+ ggml_tensor * alpha_biased = ggml_add(ctx0, alpha, model.layers[il].ssm_dt); // a + dt_bias
|
|
|
+ ggml_tensor * alpha_exp = ggml_exp(ctx0, alpha_biased); // exp(a + dt_bias)
|
|
|
+ ggml_tensor * one_tensor = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, 1); // Create scalar tensor
|
|
|
+ one_tensor = ggml_exp(ctx0, one_tensor); // e^0 = 1
|
|
|
+ ggml_tensor * one_plus_exp = ggml_add1(ctx0, alpha_exp, one_tensor); // 1 + exp(a + dt_bias)
|
|
|
+ ggml_tensor * alpha_softplus = ggml_log(ctx0, one_plus_exp); // log(1 + exp(...))
|
|
|
+ ggml_tensor * A_log_exp = ggml_exp(ctx0, model.layers[il].ssm_a); // A_log.exp()
|
|
|
+ ggml_tensor * gate_scaled = ggml_mul(ctx0, alpha_softplus, A_log_exp); // A_log.exp() * softplus
|
|
|
+ ggml_tensor * gate = ggml_neg(ctx0, gate_scaled); // - (A_log.exp() * softplus)
|
|
|
+
|
|
|
+ // Get convolution weights and bias
|
|
|
+ ggml_tensor * conv_weight = model.layers[il].ssm_conv1d;
|
|
|
+ ggml_tensor * conv_bias = nullptr; // Add if your model has conv bias
|
|
|
+
|
|
|
+ // Get recurrent states (conv_states not needed as it's handled internally by ggml_delta_net)
|
|
|
+ ggml_tensor * ssm_states_all = mctx_cur->get_s_l(il);
|
|
|
+
|
|
|
+ // Reshape tensors to match ggml_delta_net expectations
|
|
|
+ // [S, H, n_tokens, n_seqs] format
|
|
|
+ query = ggml_reshape_4d(ctx0, query, head_dim, n_heads, n_tokens, n_seqs);
|
|
|
+ key = ggml_reshape_4d(ctx0, key, head_dim, n_heads, n_tokens, n_seqs);
|
|
|
+ value = ggml_reshape_4d(ctx0, value, head_dim, n_heads, n_tokens, n_seqs);
|
|
|
+
|
|
|
+ // Beta tensor
|
|
|
+ beta = ggml_reshape_3d(ctx0, beta, n_heads, n_tokens, n_seqs);
|
|
|
+
|
|
|
+ // Get current state slice
|
|
|
+ ggml_tensor * state = ggml_view_4d(ctx0, ssm_states_all, head_dim, head_dim, n_heads, n_seqs,
|
|
|
+ ssm_states_all->nb[0], ssm_states_all->nb[1], ssm_states_all->nb[2],
|
|
|
+ kv_head * head_dim * head_dim * n_heads * ggml_element_size(ssm_states_all));
|
|
|
+ state = ggml_cont(ctx0, state);
|
|
|
+ gate = ggml_repeat(ctx0, gate, ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, 1, n_heads, n_tokens, n_seqs));
|
|
|
+
|
|
|
+ // Call the new ggml_delta_net function
|
|
|
+ ggml_tensor * output = ggml_delta_net(ctx0,
|
|
|
+ key, // k tensor
|
|
|
+ value, // v tensor
|
|
|
+ query, // q tensor
|
|
|
+ gate, // g tensor
|
|
|
+ conv_weight, // conv_weight tensor
|
|
|
+ conv_bias, // conv_bias tensor (can be nullptr)
|
|
|
+ beta, // beta tensor
|
|
|
+ state, // state tensor
|
|
|
+ 64, // chunk_size (adjust as needed)
|
|
|
+ true, // use_qk_l2norm
|
|
|
+ 1.0f // scale (adjust based on your model)
|
|
|
+ );
|
|
|
+ cb(output, "delta_net_output", il);
|
|
|
+
|
|
|
+ // Extract the output part (first half of the concatenated result)
|
|
|
+ ggml_tensor * attn_out = ggml_view_4d(ctx0, output, head_dim, n_heads, n_tokens, n_seqs, output->nb[0],
|
|
|
+ output->nb[1], output->nb[2], 0);
|
|
|
+
|
|
|
+ // Extract the new state (second half of the concatenated result)
|
|
|
+ ggml_tensor * new_state =
|
|
|
+ ggml_view_4d(ctx0, output, head_dim, head_dim, n_heads, n_seqs, output->nb[0], output->nb[1], output->nb[2],
|
|
|
+ n_tokens * head_dim * n_heads * sizeof(float));
|
|
|
+
|
|
|
+ // Update the recurrent states
|
|
|
+ ggml_build_forward_expand(
|
|
|
+ gf, ggml_cpy(ctx0, new_state,
|
|
|
+ ggml_view_1d(
|
|
|
+ ctx0, ssm_states_all, head_dim * head_dim * n_heads * n_seqs,
|
|
|
+ kv_head * n_seqs * head_dim * head_dim * n_heads * ggml_element_size(ssm_states_all))));
|
|
|
+
|
|
|
+ // Apply normalization and gating
|
|
|
+ attn_out = build_norm(attn_out, model.layers[il].ssm_norm, NULL, LLM_NORM_RMS, il);
|
|
|
+
|
|
|
+ // Output projection
|
|
|
+ cur = build_lora_mm(model.layers[il].wo, attn_out);
|
|
|
+ cb(cur, "linear_attn_out", il);
|
|
|
+
|
|
|
+ // Reshape back to original dimensions
|
|
|
+ cur = ggml_reshape_2d(ctx0, cur, n_embd, n_tokens);
|
|
|
+
|
|
|
+ return cur;
|
|
|
+ }
|
|
|
+ ggml_tensor * build_layer_ffn(ggml_tensor * cur, const llama_model & model, const int il) {
|
|
|
+ // Check if this is an MoE layer
|
|
|
+ if (model.layers[il].ffn_gate_inp != nullptr) {
|
|
|
+ // MoE branch
|
|
|
+ ggml_tensor * moe_out = build_moe_ffn(cur,
|
|
|
+ model.layers[il].ffn_gate_inp,
|
|
|
+ model.layers[il].ffn_up_exps,
|
|
|
+ model.layers[il].ffn_gate_exps,
|
|
|
+ model.layers[il].ffn_down_exps,
|
|
|
+ nullptr,
|
|
|
+ n_expert, n_expert_used,
|
|
|
+ LLM_FFN_SILU, true,
|
|
|
+ false, 0.0,
|
|
|
+ LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX,
|
|
|
+ il);
|
|
|
+ cb(moe_out, "ffn_moe_out", il);
|
|
|
+
|
|
|
+ // Add shared experts if present
|
|
|
+ if (model.layers[il].ffn_up_shexp != nullptr) {
|
|
|
+ ggml_tensor * ffn_shexp = build_ffn(cur,
|
|
|
+ model.layers[il].ffn_up_shexp, NULL, NULL,
|
|
|
+ model.layers[il].ffn_gate_shexp, NULL, NULL,
|
|
|
+ model.layers[il].ffn_down_shexp, NULL, NULL,
|
|
|
+ NULL,
|
|
|
+ LLM_FFN_SILU, LLM_FFN_PAR, il);
|
|
|
+ cb(ffn_shexp, "ffn_shexp", il);
|
|
|
+
|
|
|
+ cur = ggml_add(ctx0, moe_out, ffn_shexp);
|
|
|
+ cb(cur, "ffn_out", il);
|
|
|
+ } else {
|
|
|
+ cur = moe_out;
|
|
|
+ }
|
|
|
+ } else {
|
|
|
+ // Dense FFN branch
|
|
|
+ cur = build_ffn(cur,
|
|
|
+ model.layers[il].ffn_up, NULL, NULL,
|
|
|
+ model.layers[il].ffn_gate, NULL, NULL,
|
|
|
+ model.layers[il].ffn_down, NULL, NULL,
|
|
|
+ NULL,
|
|
|
+ LLM_FFN_SILU, LLM_FFN_PAR, il);
|
|
|
+ cb(cur, "ffn_out", il);
|
|
|
+ }
|
|
|
+
|
|
|
+ // Residual connection
|
|
|
+ cur = ggml_add(ctx0, cur, cur); // This should be the residual from before FFN
|
|
|
+ cb(cur, "ffn_residual", il);
|
|
|
+
|
|
|
+ return cur;
|
|
|
+ }
|
|
|
+};
|
|
|
+
|
|
|
+
|
|
|
llama_memory_i * llama_model::create_memory(const llama_memory_params & params, llama_cparams & cparams) const {
|
|
|
llama_memory_i * res;
|
|
|
|
|
|
@@ -19349,6 +19767,10 @@ ggml_cgraph * llama_model::build_graph(const llm_graph_params & params) const {
|
|
|
llm = std::make_unique<llm_build_smallthinker<false>>(*this, params);
|
|
|
}
|
|
|
} break;
|
|
|
+ case LLM_ARCH_QWEN3NEXT:
|
|
|
+ {
|
|
|
+ llm = std::make_unique<llm_build_qwen3next>(*this, params);
|
|
|
+ } break;
|
|
|
default:
|
|
|
GGML_ABORT("fatal error");
|
|
|
}
|
|
|
@@ -19524,6 +19946,7 @@ llama_rope_type llama_model_rope_type(const llama_model * model) {
|
|
|
case LLM_ARCH_QWEN2MOE:
|
|
|
case LLM_ARCH_QWEN3:
|
|
|
case LLM_ARCH_QWEN3MOE:
|
|
|
+ case LLM_ARCH_QWEN3NEXT:
|
|
|
case LLM_ARCH_LLADA_MOE:
|
|
|
case LLM_ARCH_OLMO2:
|
|
|
case LLM_ARCH_OLMOE:
|