Просмотр исходного кода

Proper handling for n_tokens > GGML_DELTA_NET_CHUNK

Piotr Wilkin 3 месяцев назад
Родитель
Сommit
a2c7b6794e
1 измененных файлов с 10 добавлено и 6 удалено
  1. 10 6
      src/models/llm_build_qwen3next.cpp

+ 10 - 6
src/models/llm_build_qwen3next.cpp

@@ -252,24 +252,28 @@ struct ggml_tensor * llm_build_qwen3next::delta_net(
     GGML_ASSERT(beta->ne[0] % GGML_DELTA_NET_CHUNK == 0 && beta->ne[1] == H_k && beta->ne[2] == 1 && beta->ne[3] == n_seqs);
     GGML_ASSERT(g->ne[0] % GGML_DELTA_NET_CHUNK == 0 && g->ne[1] == H_k && g->ne[2] == 1 && g->ne[3] == n_seqs);
 
-    ggml_tensor * beta_unsq = ggml_cont_4d(ctx, beta, 1, num_chunks * GGML_DELTA_NET_CHUNK, H_k, n_seqs);
-    ggml_tensor * beta_bcast = ggml_repeat_4d(ctx, beta_unsq, S_v, num_chunks * GGML_DELTA_NET_CHUNK, H_k, n_seqs);
+    ggml_tensor * beta_unsq = ggml_cont_4d(ctx, beta, 1, GGML_DELTA_NET_CHUNK * num_chunks, H_k, n_seqs);
+    ggml_tensor * beta_bcast = ggml_repeat_4d(ctx, beta_unsq, S_v, GGML_DELTA_NET_CHUNK * num_chunks, H_k, n_seqs);
     cb(beta_unsq, "beta_unsq", il);
     cb(beta_bcast, "beta_bcast", il);
 
     struct ggml_tensor * v_beta = ggml_mul(ctx, v, beta_bcast);
+    v_beta = ggml_reshape_4d(ctx, v_beta, S_v, GGML_DELTA_NET_CHUNK, H_k * num_chunks, n_seqs);
     cb(v_beta, "v_beta", il);
     struct ggml_tensor * k_beta = ggml_mul(ctx, k, beta_bcast);
+    k_beta = ggml_reshape_4d(ctx, k_beta, S_v, GGML_DELTA_NET_CHUNK, H_k * num_chunks, n_seqs);
     cb(k_beta, "k_beta", il);
+    k = ggml_reshape_4d(ctx, k, S_v, GGML_DELTA_NET_CHUNK, H_k * num_chunks, n_seqs);
+    cb(k_beta, "k_reshape", il);
     struct ggml_tensor * g_cumsum = ggml_cumsum(ctx, g);
     cb(g_cumsum, "g_cumsum", il);
         
-    struct ggml_tensor * gcs_i = ggml_cont_4d(ctx, g_cumsum, num_chunks * GGML_DELTA_NET_CHUNK, 1, H_v, n_seqs);  // [chunk_size, 1, n_tokens, n_seqs]
-    struct ggml_tensor * gcs_j = ggml_cont_4d(ctx, g_cumsum, 1, num_chunks * GGML_DELTA_NET_CHUNK, H_v, n_seqs);  // [1, chunk_size, n_tokens, n_seqs]
+    struct ggml_tensor * gcs_i = ggml_cont_4d(ctx, g_cumsum, GGML_DELTA_NET_CHUNK, 1, num_chunks * H_v, n_seqs);  // [chunk_size, 1, n_tokens, n_seqs]
+    struct ggml_tensor * gcs_j = ggml_cont_4d(ctx, g_cumsum, 1, GGML_DELTA_NET_CHUNK, num_chunks * H_v, n_seqs);  // [1, chunk_size, n_tokens, n_seqs]
     
     // Broadcast both tensors to [chunk_size, chunk_size, H_v, n_seqs]
-    struct ggml_tensor * gcs_i_broadcast = ggml_repeat_4d(ctx, gcs_i, num_chunks * GGML_DELTA_NET_CHUNK, num_chunks * GGML_DELTA_NET_CHUNK, H_v, n_seqs);  // [chunk_size, 1, H_v, n_seqs] -> [chunk_size, chunk_size, H_v, n_seqs]
-    struct ggml_tensor * gcs_j_broadcast = ggml_repeat_4d(ctx, gcs_j, num_chunks * GGML_DELTA_NET_CHUNK, num_chunks * GGML_DELTA_NET_CHUNK, H_v, n_seqs);  // [1, chunk_size, H_v, n_seqs] -> [chunk_size, chunk_size, H_v, n_seqs]
+    struct ggml_tensor * gcs_i_broadcast = ggml_repeat_4d(ctx, gcs_i, GGML_DELTA_NET_CHUNK, GGML_DELTA_NET_CHUNK, num_chunks * H_v, n_seqs);  // [chunk_size, 1, H_v, n_seqs] -> [chunk_size, chunk_size, H_v, n_seqs]
+    struct ggml_tensor * gcs_j_broadcast = ggml_repeat_4d(ctx, gcs_j, GGML_DELTA_NET_CHUNK, GGML_DELTA_NET_CHUNK, num_chunks * H_v, n_seqs);  // [1, chunk_size, H_v, n_seqs] -> [chunk_size, chunk_size, H_v, n_seqs]
     
     struct ggml_tensor * decay_mask = ggml_sub(ctx, gcs_j_broadcast, gcs_i_broadcast); 
     cb(decay_mask, "sub", il);