Bladeren bron

Add proper check for previous state

Piotr Wilkin 2 maanden geleden
bovenliggende
commit
2fdbf16eb1
3 gewijzigde bestanden met 7 toevoegingen en 2 verwijderingen
  1. 4 0
      src/llama-memory-recurrent.cpp
  2. 1 0
      src/llama-memory-recurrent.h
  3. 2 2
      src/models/llm_build_qwen3next.cpp

+ 4 - 0
src/llama-memory-recurrent.cpp

@@ -1144,3 +1144,7 @@ ggml_tensor * llama_memory_recurrent_context::get_s_l(int32_t il) const {
 int32_t llama_memory_recurrent_context::s_copy(int i) const {
     return  mem->cells[i + mem->head].src0;
 }
+
+bool llama_memory_recurrent_context::has_previous_state() const {
+    return mem->cells[mem->head].pos >= 0;
+}

+ 1 - 0
src/llama-memory-recurrent.h

@@ -160,6 +160,7 @@ public:
     ggml_tensor * get_s_l(int32_t il) const;
 
     int32_t s_copy(int i) const;
+    bool has_previous_state() const;
 
 private:
     const llama_memory_status status;

+ 2 - 2
src/models/llm_build_qwen3next.cpp

@@ -605,7 +605,7 @@ ggml_tensor * llm_build_qwen3next::build_qwen3next_linear_attn_layer(llm_graph_i
     ggml_tensor * conv_states_all = mctx_cur->get_r_l(il);
     ggml_tensor * ssm_states_all  = mctx_cur->get_s_l(il);
 
-    bool is_generation = mctx_cur->get_rs_z() < 0;
+    bool use_precomputed_states = n_seq_tokens == 1 && mctx_cur->has_previous_state();
 
     // Build the convolution states tensor
     ggml_tensor * conv_states = build_rs(inp, conv_states_all, hparams.n_embd_r(), n_seqs);
@@ -719,7 +719,7 @@ ggml_tensor * llm_build_qwen3next::build_qwen3next_linear_attn_layer(llm_graph_i
 
     // Choose between delta_net and delta_net_recurrent based on generation mode
     ggml_tensor * attn_out;
-    if (is_generation) {
+    if (use_precomputed_states) {
         // Use delta_net_recurrent for single token generation
         attn_out = delta_net_recurrent(ctx0, q_conv, k_conv, v_conv, gate, beta, state, true, hparams.f_norm_rms_eps, il);
     } else {