@@ -10854,6 +10854,7 @@ static void print_debug_info(float * data, size_t size, const char * name, int64
#endif
}

+// chunked version of delta_net (for prompt processing)
void ggml_compute_forward_delta_net_f32(const ggml_compute_params * params, ggml_tensor * dst) {
const struct ggml_tensor * src0 = dst->src[0]; // q (already normalized and scaled)
const struct ggml_tensor * src1 = dst->src[1]; // k (already normalized)
@@ -10870,13 +10871,20 @@ void ggml_compute_forward_delta_net_f32(const ggml_compute_params * params, ggml
const int64_t original_n_tokens = (int64_t) dst->op_params[3]; // Get original sequence length
const int64_t n_tokens = original_n_tokens; // Use the original sequence length
const int64_t n_seqs = src0->ne[3]; // q tensor has n_seqs in dim 3
-
+ // Calculate chunk size
+ const int64_t chunk_size = GGML_DELTA_NET_CHUNK;
+ const int64_t pad_size = (chunk_size - n_tokens % chunk_size) % chunk_size;
+ const int64_t num_chunks = (n_tokens + pad_size) / chunk_size;

// Add assertions to verify tensor dimensions
- GGML_ASSERT(src0->ne[3] == n_seqs); // q tensor
- GGML_ASSERT(src1->ne[3] == n_seqs); // k tensor
- GGML_ASSERT(src2->ne[3] == n_seqs); // v tensor
- GGML_ASSERT(src3->ne[3] == n_seqs); // g tensor
+ GGML_ASSERT(src0->ne[3] == n_seqs && src0->ne[2] == num_chunks * H_v); // q tensor
+ GGML_ASSERT(src1->ne[3] == n_seqs && src1->ne[2] == num_chunks * H_v); // k tensor
+ GGML_ASSERT(src2->ne[3] == n_seqs && src2->ne[2] == num_chunks * H_v); // v tensor
+ GGML_ASSERT(src3->ne[3] == n_seqs && src3->ne[2] == num_chunks * H_v); // g tensor
+ GGML_ASSERT(src5->ne[3] == n_seqs && src5->ne[2] == num_chunks * H_v); // decay mask tensor
+ GGML_ASSERT(src6->ne[3] == n_seqs && src6->ne[2] == num_chunks * H_v); // v_beta tensor
+ GGML_ASSERT(src7->ne[3] == n_seqs && src7->ne[2] == num_chunks * H_v); // k_beta tensor
+
GGML_ASSERT(src4->ne[3] == n_seqs); // state tensor

float * dst_data = (float *) dst->data;
@@ -10894,11 +10902,6 @@ void ggml_compute_forward_delta_net_f32(const ggml_compute_params * params, ggml
return;
}

- // Calculate chunk size
- const int64_t chunk_size = GGML_DELTA_NET_CHUNK;
- const int64_t pad_size = (chunk_size - n_tokens % chunk_size) % chunk_size;
- const int64_t num_chunks = (n_tokens + pad_size) / chunk_size;
-
float * state_data = (float *) src4->data;

// Init new state with initial state (will probably be zeroes)
@@ -10970,14 +10973,14 @@ void ggml_compute_forward_delta_net_f32(const ggml_compute_params * params, ggml
// Process each chunk with all sequences and heads together
for (int64_t chunk = 0; chunk < num_chunks; chunk++) {
// Create lambdas for tensor access similar to recurrent function
- const auto q_chunk = [chunk, src0](int64_t seq, int64_t head, int64_t token_idx, int64_t i) {
- return ggml_get_f32_nd(src0, i, chunk * chunk_size + token_idx, head, seq);
+ const auto q_chunk = [chunk, src0, num_chunks](int64_t seq, int64_t head, int64_t token_idx, int64_t i) {
+ return ggml_get_f32_nd(src0, i, token_idx, head * num_chunks + chunk, seq);
};
- const auto k_chunk = [chunk, src1](int64_t seq, int64_t head, int64_t token_idx, int64_t i) {
- return ggml_get_f32_nd(src1, i, chunk * chunk_size + token_idx, head, seq);
+ const auto k_chunk = [chunk, src1, num_chunks](int64_t seq, int64_t head, int64_t token_idx, int64_t i) {
+ return ggml_get_f32_nd(src1, i, token_idx, head * num_chunks + chunk, seq);
};
- const auto g_chunk = [chunk, src3](int64_t seq, int64_t head, int64_t token_idx) {
- return ggml_get_f32_nd(src3, chunk * chunk_size + token_idx, 0, head, seq);
+ const auto g_chunk = [chunk, src3, num_chunks](int64_t seq, int64_t head, int64_t token_idx) {
+ return ggml_get_f32_nd(src3, token_idx, 0, head * num_chunks + chunk, seq);
};

// Allocate per-chunk arrays containing all sequences and heads
@@ -10999,7 +11002,8 @@ void ggml_compute_forward_delta_net_f32(const ggml_compute_params * params, ggml
for (int64_t head = 0; head < H_v; head++) {
float * q_ptr = q_chunk_data + seq * (chunk_size * S_v * H_v) + head * (chunk_size * S_v);
float * k_ptr = k_chunk_data + seq * (chunk_size * S_v * H_v) + head * (chunk_size * S_v);
- float * g_ptr = g + seq * (chunk_size * H_v) + head * chunk_size;
+ float * g_ptr = g + (chunk_size * H_v * num_chunks) * seq + chunk_size * (head * num_chunks + chunk);
+
float * q_g_exp_ptr = q_g_exp + seq * (chunk_size * S_v * H_v) + head * (chunk_size * S_v);

// Fill q, k, decay_mask, and g data
@@ -11030,7 +11034,7 @@ void ggml_compute_forward_delta_net_f32(const ggml_compute_params * params, ggml
// Now compute attention for all sequences and heads together
for (int64_t seq = 0; seq < n_seqs; seq++) {
for (int64_t head = 0; head < H_v; head++) {
- float * attn_ptr = attn + seq * (chunk_size * chunk_size * H_v) + head * (chunk_size * chunk_size);
+ float * attn_ptr = attn + seq * (chunk_size * chunk_size * num_chunks * H_v) + (head * num_chunks + chunk) * (chunk_size * chunk_size);
const float * q_ptr = q_chunk_data + seq * (chunk_size * S_v * H_v) + head * (chunk_size * S_v);
const float * k_ptr = k_chunk_data + seq * (chunk_size * S_v * H_v) + head * (chunk_size * S_v);

@@ -11046,13 +11050,12 @@ void ggml_compute_forward_delta_net_f32(const ggml_compute_params * params, ggml
}
print_debug_info(attn, chunk_size * chunk_size * H_v * n_seqs, "q_k_trans", chunk);

-
for (int64_t seq = 0; seq < n_seqs; seq++) {
for (int64_t head = 0; head < H_v; head++) {
for (int64_t i = 0; i < chunk_size; i++) {
for (int64_t j = 0; j < chunk_size; j++) {
- float * attn_ptr = attn + seq * (chunk_size * chunk_size * H_v) + head * (chunk_size * chunk_size);
- const float * decay_mask_ptr = (float *) src5->data + seq * (chunk_size * chunk_size * H_v) + head * (chunk_size * chunk_size);
+ float * attn_ptr = attn + seq * (chunk_size * chunk_size * num_chunks * H_v) + (head * num_chunks + chunk) * (chunk_size * chunk_size);
+ const float * decay_mask_ptr = (float *) src5->data + seq * (chunk_size * chunk_size * num_chunks * H_v) + (head * num_chunks + chunk) * (chunk_size * chunk_size);
float attn_val = attn_ptr[i * chunk_size + j] * decay_mask_ptr[i * chunk_size + j];
// Apply simple causal attention mask (upper triangular with diagonal=1)
// This corresponds to: torch.triu(torch.ones(chunk_size, chunk_size), diagonal=1)
@@ -11065,7 +11068,7 @@ void ggml_compute_forward_delta_net_f32(const ggml_compute_params * params, ggml
}
}

- print_debug_info(attn, chunk_size * chunk_size * H_v * n_seqs, "attn_step4_new_chunk", chunk);
+ print_debug_info(attn, chunk_size * chunk_size * H_v * num_chunks * n_seqs, "attn_step4_new_chunk", chunk);

// v_prime = (k_cumdecay[:, :, i]) @ last_recurrent_state
// k_cumdecay has shape [chunk_size, v_head_dim], state has shape [v_head_dim, v_head_dim]
@@ -11084,7 +11087,7 @@ void ggml_compute_forward_delta_net_f32(const ggml_compute_params * params, ggml
// Use regular matrix multiplication for attn @ v_new
for (int64_t seq = 0; seq < n_seqs; seq++) {
for (int64_t head = 0; head < H_v; head++) {
- const float * attn_ptr = attn + seq * (chunk_size * chunk_size * H_v) + head * (chunk_size * chunk_size);
+ const float * attn_ptr = attn + seq * (chunk_size * chunk_size * num_chunks * H_v) + (head * num_chunks + chunk) * (chunk_size * chunk_size);
const float * v_new_ptr = v_new + seq * (chunk_size * S_v * H_v) + head * (chunk_size * S_v);
float * attn_v_new_ptr = attn_v_new + seq * (chunk_size * S_v * H_v) + head * (chunk_size * S_v);