@@ -10590,6 +10590,51 @@ static void ggml_delta_net_matmul_f32(const float * a, const int64_t rows_a, con
     }
 }
 
+// Helper function to apply triangular updates to attention matrix
+static void delta_net_compute_diagonal_updates(
+    float * attn,
+    const int64_t chunk_size,
+    const int64_t n_heads,
+    const int64_t n_seqs) {
+
+    // Apply triangular updates like in the Python reference:
+    // for i in range(1, chunk_size):
+    //     row = attn[..., i, :i].clone()
+    //     sub = attn[..., :i, :i].clone()
+    //     attn[..., i, :i] = row + (row.unsqueeze(-1) * sub).sum(-2)
+    // attn = attn + torch.eye(chunk_size, dtype=attn.dtype, device=attn.device)
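+    //
+    // In effect, for a strictly lower-triangular input A this recurrence plus the
+    // final eye() add builds (I - A)^{-1} row by row (forward substitution).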
+
+    for (int64_t head = 0; head < n_heads; head++) {
+        for (int64_t seq = 0; seq < n_seqs; seq++) {
+            // Get pointer to this head's attention matrix
+            float * attn_head = attn + (head * chunk_size * chunk_size * n_seqs) + (seq * chunk_size * chunk_size);
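+            // (Layout assumption: each [chunk_size x chunk_size] block is contiguous,
+            // with n_seqs consecutive blocks per head; this must match how the attn
+            // tensor is built by the caller.)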
+
+            // Snapshot the current row before overwriting it; rows k < i must already
+            // hold their updated values, because the Python reference re-reads the
+            // progressively updated sub-matrix at every step
+            float * row_copy = (float *) malloc(chunk_size * sizeof(float));
+
+            // Apply triangular updates row by row
+            for (int64_t i = 1; i < chunk_size; i++) {
+                memcpy(row_copy, attn_head + i * chunk_size, i * sizeof(float));
+                for (int64_t j = 0; j < i; j++) {
+                    float sum = 0.0f;
+                    // Compute (row.unsqueeze(-1) * sub).sum(-2) for column j, reading the
+                    // already-updated rows k < i and the original values of row i
+                    for (int64_t k = 0; k < i; k++) {
+                        sum += row_copy[k] * attn_head[k * chunk_size + j];
+                    }
+                    attn_head[i * chunk_size + j] = row_copy[j] + sum;
+                }
+            }
+
+            // Add identity matrix
+            for (int64_t i = 0; i < chunk_size; i++) {
+                attn_head[i * chunk_size + i] += 1.0f;
+            }
+
+            free(row_copy);
+        }
+    }
+}
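+// Worked 2x2 example (illustration, not from the reference): for one head with
+// input A = {{0, 0}, {a, 0}}, the i = 1 row update leaves a unchanged (the 1x1
+// sub-matrix is zero) and the diagonal add yields {{1, 0}, {a, 1}} = (I - A)^{-1}.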
+
 void ggml_compute_forward_delta_net_f32(const ggml_compute_params * params, ggml_tensor * dst) {
     const struct ggml_tensor * src0 = dst->src[0]; // q (already normalized and scaled)
     const struct ggml_tensor * src1 = dst->src[1]; // k (already normalized)
@@ -10614,11 +10659,11 @@ void ggml_compute_forward_delta_net_f32(const ggml_compute_params * params, ggml
     GGML_ASSERT(src1->ne[3] == n_seqs); // k tensor
     GGML_ASSERT(src2->ne[3] == n_seqs); // v tensor
     GGML_ASSERT(src3->ne[3] == n_seqs); // g tensor
-    GGML_ASSERT(src4->ne[3] == n_seqs); // beta tensor
+    GGML_ASSERT(src4->ne[3] == n_seqs); // state tensor
 
     float * dst_data = (float *) dst->data;
     // Following GLA pattern: output is first part, state is second part
-    float * output = dst_data;  // [S_v * H_v, n_tokens, 1, 1] - only real sequence length, not padded
+    float * output = dst_data; // [S_v * H_v, n_tokens, 1, 1] - only real sequence length, not padded
     float * new_state = dst_data + (S_v * H_v * n_tokens); // [S_v * H_v, S_v * n_seqs, 1, 1]
 
     const int ith = params->ith;
@@ -10633,414 +10678,378 @@ void ggml_compute_forward_delta_net_f32(const ggml_compute_params * params, ggml
     // Clear output and new state section
     memset(output, 0, ((S_v * H_v * n_tokens) + (S_v * H_v * S_v * n_seqs)) * sizeof(float));
 
-    // Get tensor data pointers
-    float * state_data = (float *) src4->data;
-    float * decay_mask = (float *) src5->data;
-
-    // Allocate temporary buffers for computation
+    // Calculate chunk size
     const int64_t chunk_size = GGML_DELTA_NET_CHUNK;
-    // The first dimension is the chunk_size, second is head_dim, third is num_heads, fourth is n_seqs
-    // Note: In reference Python implementation, tensors are padded to multiple of chunk_size
-    // but the output only contains the real sequence length, not the padded length
-
-    // Calculate the actual padded sequence length for internal processing
-    const int64_t pad_size = (chunk_size - n_tokens % chunk_size) % chunk_size;
-    const int64_t total_sequence_length = n_tokens + pad_size;
-    const int64_t n_chunks = (total_sequence_length + chunk_size - 1) / chunk_size; // Ceiling division
-
-    // Temporary buffers for each chunk
-    std::vector<float> attn(chunk_size * chunk_size, 0.0f);
-    std::vector<float> value(chunk_size * S_v, 0.0f);
-    std::vector<float> k_cumdecay(chunk_size * S_k, 0.0f);
-    std::vector<double> g_exp(chunk_size, 0.0f);
-    std::vector<float> g_cumsum(chunk_size, 0.0f);
-    std::vector<float> last_state(S_v * S_v * H_v, 0.0f);
-
-    // Initialize last_state with input state data
-    // State format in GGML: [S_v, S_v * H_v, 1, 1] where S_v * H_v = S_v * num_heads
-    // The state tensor has format [S_v, S_v * H_v, 1, 1] where second dimension is S_v * num_heads
-    // For delta_net, S_k == S_v (both k and v have the same head dimension)
-    for (int64_t h = 0; h < H_v; h++) {
-        for (int64_t d1 = 0; d1 < S_v; d1++) {
-            for (int64_t d2 = 0; d2 < S_v; d2++) {
-                // GGML state index: [d1, d2 + h*S_v, 0, 0] in flattened form
-                int64_t ggml_state_idx = d1 * (S_v * H_v) + (d2 + h * S_v);
-                // Our computed state index: [d1, d2 + h*S_v]
-                int64_t computed_state_idx = d1 * (S_v * H_v) + (d2 + h * S_v);
-                last_state[computed_state_idx] = state_data[ggml_state_idx];
-            }
-        }
-    }
-
-    // Maintain running cumulative sum across all chunks
-    std::vector<float> running_cumsum(n_tokens, 0.0f);
-
-    // Process each chunk
-    for (int64_t chunk_idx = 0; chunk_idx < n_chunks; chunk_idx++) {
-        // Process each head and sequence
-        for (int64_t h = 0; h < H_k; h++) {
-            for (int64_t seq = 0; seq < n_seqs; seq++) {
-                // Extract chunk data for this head and sequence
-                std::vector<float> q_chunk(chunk_size * S_k);
-                std::vector<float> k_chunk(chunk_size * S_k);
-                std::vector<float> v_chunk(chunk_size * S_v);
-                std::vector<float> v_beta_chunk(chunk_size * S_v);
-                std::vector<float> k_beta_chunk(chunk_size * S_k);
-                std::vector<float> g_chunk(chunk_size);
-
-                // Initialize chunks with zeros for padding
-                std::fill(q_chunk.begin(), q_chunk.end(), 0.0f);
-                std::fill(k_chunk.begin(), k_chunk.end(), 0.0f);
-                std::fill(v_chunk.begin(), v_chunk.end(), 0.0f);
-                std::fill(v_beta_chunk.begin(), v_beta_chunk.end(), 0.0f);
-                std::fill(k_beta_chunk.begin(), k_beta_chunk.end(), 0.0f);
-                std::fill(g_chunk.begin(), g_chunk.end(), 0.0f);
-
-                // Determine actual tokens in this chunk
-                int64_t tokens_in_chunk = std::min(chunk_size, n_tokens - chunk_idx * chunk_size);
-
-                // Copy data for this chunk
-                for (int64_t t = 0; t < tokens_in_chunk; t++) {
-                    int64_t actual_pos = chunk_idx * chunk_size + t; // Position in the original sequence
-
-                    // Only copy if this position is within the original sequence length
-                    if (actual_pos < n_tokens) {
-                        // Calculate indices in GGML format [chunk_size, head_dim, num_heads, n_seqs]
-                        for (int64_t d = 0; d < S_k; d++) {
-                            q_chunk[t * S_k + d] = ggml_get_f32_nd(src0, actual_pos, d, h, seq);
-                            k_chunk[t * S_k + d] = ggml_get_f32_nd(src1, actual_pos, d, h, seq);
-                            k_beta_chunk[t * S_k + d] = ggml_get_f32_nd(src7, actual_pos, d, h, seq);
-                        }
-
-                        for (int64_t d = 0; d < S_v; d++) {
-                            v_chunk[t * S_v + d] = ggml_get_f32_nd(src2, actual_pos, d, h, seq);
-                            v_beta_chunk[t * S_v + d] = ggml_get_f32_nd(src6, actual_pos, d, h, seq);
-                        }
-
-                        if (actual_pos <
-                            n_tokens) { // Only copy if this position is within the original sequence length
-                            // Use the safe GGML function to access tensor values
-                            g_chunk[t] = ggml_get_f32_nd(src3, actual_pos, 0, h, seq);
-                        } else {
-                            // For padded positions, set to 0 (or a default value)
-                            g_chunk[t] = 0.0f;
-                        }
-                    } else {
-                        // For padded positions beyond original sequence, set to 0
-                        for (int64_t d = 0; d < S_k; d++) {
-                            q_chunk[t * S_k + d] = 0.0f;
-                            k_chunk[t * S_k + d] = 0.0f;
-                            k_beta_chunk[t * S_k + d] = 0.0f;
-                        }
-                        for (int64_t d = 0; d < S_v; d++) {
-                            v_chunk[t * S_v + d] = 0.0f;
-                            v_beta_chunk[t * S_v + d] = 0.0f;
-                        }
-                        g_chunk[t] = 0.0f;
-                    }
-                }
-
-                // In Python, cumsum is applied to each chunk separately after reshaping
-                // So we need to compute cumsum within this chunk only
-
-                // g_chunk already contains the cumsum values from src3 (g_cumsum), so use them directly
-                for (int64_t i = 0; i < tokens_in_chunk; i++) { // Only for actual tokens, not full chunk size
-                    g_cumsum[i] = g_chunk[i];
-                }
-
-                // For padded positions, set cumsum values to 0
-                for (int64_t i = tokens_in_chunk; i < chunk_size; i++) {
-                    g_cumsum[i] = 0.0f;
-                }
-
-                // Compute g_exp from cumulative sums (like Python: g.cumsum().exp())
-                // Apply numerical stability to prevent underflow for very negative values
-                for (int64_t i = 0; i < tokens_in_chunk; i++) { // Only for actual tokens, not full chunk size
-                    // Use double precision for exponential to avoid overflow/underflow
-                    // Apply lower bound to prevent extreme underflow - exp(-50) is about 1.9e-22
-                    double g_val = (double) g_cumsum[i];
-                    double g_exp_double = exp(g_val);
-                    g_exp[i] = g_exp_double;
-                }
-
-                // For padded positions, set exp values to 0
-                for (int64_t i = tokens_in_chunk; i < chunk_size; i++) {
-                    g_exp[i] = 0.0f;
-                }
-                // Step 1: Compute k_beta @ key.T (this corresponds to the Python: k_beta @ key.transpose(-1, -2))
-                // Only compute for actual tokens in chunk
-                ggml_compute_k_beta_key_t_f32(k_beta_chunk.data(), k_chunk.data(), attn.data(), tokens_in_chunk,
-                                              S_k); // Use actual tokens, not full chunk size
-
-                // Apply precomputed decay mask from src5 and negate the result (like Python: -(...))
-                // The decay mask is computed in ggml_delta_net in ggml.c and passed as src5
-                // Apply the precomputed decay mask from src5 (decay_mask tensor)
-                // The decay_mask tensor now contains exp(g_cumsum[j] - g_cumsum[i]) values
-                // where g_cumsum[j] - g_cumsum[i] is computed in the main function
-                for (int64_t i = 0; i < tokens_in_chunk; i++) { // Only for actual tokens
-                    for (int64_t j = 0; j < tokens_in_chunk; j++) { // Only for actual tokens
-                        // Get decay mask value from precomputed tensor
-                        // src5 decay_mask has shape [chunk_size, chunk_size, H_k, n_seqs] in GGML format
-                        // Format: [i_pos, j_pos, head, seq] - represents exp(g_cumsum[j] - g_cumsum[i])
-                        float decay_val = ggml_get_f32_nd(
-                            src5, i, j, h, seq); // [i, j, h, seq] to get exp(g_cumsum[j] - g_cumsum[i]) for head h
-                        if (j <= i) { // Only apply to lower triangular part (i >= j)
-                            // The decay_val already contains exp(g_cumsum[j] - g_cumsum[i]), no need for additional exponential
-                            // Apply the decay mask and negate (like Python: -((k_beta @ key.T) * decay_mask))
-                            attn[i * chunk_size + j] = -attn[i * chunk_size + j] * decay_val;
-                        } else {
-                            attn[i * chunk_size + j] =
-                                0.0f; // Zero out upper triangular part (like Python: masked_fill(mask, 0))
-                        }
+    const int64_t pad_size = (chunk_size - n_tokens % chunk_size) % chunk_size;
+    const int64_t num_chunks = (n_tokens + pad_size) / chunk_size;
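+    // (num_chunks covers the padded length; the chunk loops below run over full
+    // chunks and assume any padded tail positions were zero-filled upstream.)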
+
+    // Apply triangular updates to the precomputed attention matrix
+    // This is the missing piece that was causing the attention matrix to have all zeros
+    float * attn_data = (float *) src8->data;
+    delta_net_compute_diagonal_updates(attn_data, chunk_size, H_v, n_seqs);
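+    // (Note: this writes the updated matrix back into the src8 buffer in place.)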
+
+    // Debug: Check attention matrix after triangular updates
+    float attn_after_updates_sum = 0.0f;
+    float attn_after_updates_max = 0.0f;
+    for (int64_t i = 0; i < chunk_size * chunk_size * H_v * n_seqs; i++) {
+        attn_after_updates_sum += attn_data[i];
+        attn_after_updates_max = fmaxf(attn_after_updates_max, fabsf(attn_data[i]));
+    }
+    printf("C++ attn_after_triangular_updates sum = %f, max = %f\n", attn_after_updates_sum, attn_after_updates_max);
+
+    // Compute value = attn @ v_beta and k_cumdecay = attn @ (k_beta * g.exp().unsqueeze(-1))
+    // These should be computed once before the chunk loop, like in the Python reference
+    printf("=== Computing value and k_cumdecay before chunk loop ===\n");
+
+    // Compute value and k_cumdecay for each head and sequence
+    for (int64_t seq_idx = 0; seq_idx < n_seqs; seq_idx++) {
+        for (int64_t head_idx = 0; head_idx < H_v; head_idx++) {
+            // Get offsets for this head and sequence
+            const int64_t attn_offset = (seq_idx * src8->nb[3] + head_idx * src8->nb[2]) / sizeof(float);
+            const int64_t v_beta_offset = (seq_idx * src6->nb[3] + head_idx * src6->nb[2]) / sizeof(float);
+            const int64_t k_beta_offset = (seq_idx * src7->nb[3] + head_idx * src7->nb[2]) / sizeof(float);
+            const int64_t g_offset = (seq_idx * src3->nb[3] + head_idx * src3->nb[2]) / sizeof(float);
+
+            float * attn_precomputed = (float *) src8->data + attn_offset;
+            float * v_beta_ptr = (float *) src6->data + v_beta_offset;
+            float * k_beta_ptr = (float *) src7->data + k_beta_offset;
+            float * g_vals = (float *) src3->data + g_offset;
+
+            // Compute value = attn @ v_beta
+            float * value = (float *) malloc(chunk_size * S_v * sizeof(float));
+            for (int64_t i = 0; i < chunk_size; i++) {
+                for (int64_t d = 0; d < S_v; d++) {
+                    float sum = 0.0f;
+                    for (int64_t j = 0; j < chunk_size; j++) {
+                        sum += attn_precomputed[i * chunk_size + j] * v_beta_ptr[j * S_v + d];
                     }
+                    value[i * S_v + d] = sum;
                 }
-
-                // Step 2: Apply triangular updates (equivalent to Python's complex triangular update)
-                // Python: for i in range(1, chunk_size):
-                //     row = attn[..., i, :i].clone() // row = attn[i, 0:i]
-                //     sub = attn[..., :i, :i].clone() // sub = attn[0:i, 0:i]
-                //     attn[..., i, :i] = row + (row.unsqueeze(-1) * sub).sum(-2)
-                // This means: new_attn[i, j] = old_attn[i, j] + sum_k(old_attn[i, k] * old_attn[k, j]) for k < i
-                for (int64_t i = 1; i < tokens_in_chunk; i++) {
-                    // Store the original row values to avoid using updated values in computation
-                    std::vector<float> original_row(i);
-                    for (int64_t j = 0; j < i; j++) {
-                        original_row[j] = attn[i * tokens_in_chunk + j]; // Use tokens_in_chunk for indexing
+            }
+
+            float value_sum = 0.0f;
+            for (int64_t i = 0; i < chunk_size * S_v; i++) {
+                value_sum += value[i];
+            }
+            printf("C++ PRE-CHUNK value_sum = %f (head %" PRId64 ", seq %" PRId64 ")\n", value_sum, head_idx, seq_idx);
+
+            // Compute k_cumdecay = attn @ (k_beta * g.exp().unsqueeze(-1))
+            float * k_cumdecay = (float *) malloc(chunk_size * S_k * sizeof(float));
+            for (int64_t i = 0; i < chunk_size; i++) {
+                for (int64_t d = 0; d < S_k; d++) {
+                    float sum = 0.0f;
+                    for (int64_t j = 0; j < chunk_size; j++) {
+                        float g_exp = expf(g_vals[j]);
+                        sum += attn_precomputed[i * chunk_size + j] * k_beta_ptr[j * S_k + d] * g_exp;
                     }
-
-                    for (int64_t j = 0; j < i; j++) {
+                    k_cumdecay[i * S_k + d] = sum;
+                }
+            }
+
+            float k_cumdecay_sum = 0.0f;
+            for (int64_t i = 0; i < chunk_size * S_k; i++) {
+                k_cumdecay_sum += k_cumdecay[i];
+            }
+            printf("C++ PRE-CHUNK k_cumdecay_sum = %f (head %" PRId64 ", seq %" PRId64 ")\n", k_cumdecay_sum, head_idx, seq_idx);
+
+            free(value);
+            free(k_cumdecay);
+        }
+    }
+    printf("=== End pre-chunk computations ===\n");
+
+    // Initialize last_recurrent_state
+    // last_recurrent_state = torch.zeros(batch_size, sequence_length, k_head_dim, v_head_dim).to(value)
+    //     if initial_state is None else initial_state.to(value)
+    float * initial_state_ptr = (float *) src4->data;
+
+    // If initial_state is provided, copy it to new_state, otherwise initialize new_state to zeros
+    // Output state layout: [S_v * H_v, S_v * n_seqs, 1, 1]
+    // This means: [n_heads * v_head_dim, v_head_dim * n_seqs, 1, 1]
+    // So total size is: S_v * H_v * S_v * n_seqs
+    if (initial_state_ptr != NULL) {
+        // Copy initial state to new state
+        memcpy(new_state, initial_state_ptr, S_v * H_v * S_v * n_seqs * sizeof(float));
+    } else {
+        // Initialize new state to zeros
+        memset(new_state, 0, S_v * H_v * S_v * n_seqs * sizeof(float));
+    }
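+    // (Assumes src4 already uses the output state layout above, so a flat copy is valid.)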
+
+    // Process each chunk for the main computation
+    // Following the Python reference implementation exactly
+    for (int64_t chunk_idx = 0; chunk_idx < num_chunks; chunk_idx++) {
+        for (int64_t seq_idx = 0; seq_idx < n_seqs; seq_idx++) {
+            // Process each head in this chunk
+            for (int64_t head_idx = 0; head_idx < H_v; head_idx++) {
+                // Get the recurrent state for this sequence and head
+                // Output state layout: [S_v * H_v, S_v * n_seqs, 1, 1]
+                // For each head and sequence, we need to find the correct state slice
+                // The state is organized as: [head_idx * S_v * S_v * n_seqs + seq_idx * S_v * S_v]
+                float * last_recurrent_state = new_state + (head_idx * S_v * S_v * n_seqs) + (seq_idx * S_v * S_v);
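+                // (last_recurrent_state aliases new_state, so the per-chunk updates below
+                // accumulate into the output state buffer in place.)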
+                printf("\n=== C++ Processing chunk %" PRId64 ", seq %" PRId64 ", head %" PRId64 " ===\n", chunk_idx, seq_idx, head_idx);
+                // Get pointers to current chunk data for this head
+                // GGML tensor layout: [S_k/S_v, chunk_size, H_v, n_seqs]
+                // Python layout: [batch_size, sequence_length, num_heads, k_head_dim]
+                // After transpose: [batch_size, num_heads, sequence_length, k_head_dim]
+
+                // For GGML: ne[0]=S_k/S_v, ne[1]=chunk_size, ne[2]=H_v, ne[3]=n_seqs
+                // nb[0]=sizeof(float), nb[1]=sizeof(float)*ne[0], etc.
+
+                const int64_t q_offset = (seq_idx * src0->nb[3] + head_idx * src0->nb[2]) / sizeof(float);
+                const int64_t k_offset = (seq_idx * src1->nb[3] + head_idx * src1->nb[2]) / sizeof(float);
+                const int64_t v_offset = (seq_idx * src2->nb[3] + head_idx * src2->nb[2]) / sizeof(float);
+                const int64_t g_offset = (seq_idx * src3->nb[3] + head_idx * src3->nb[2]) / sizeof(float);
+                const int64_t v_beta_offset = (seq_idx * src6->nb[3] + head_idx * src6->nb[2]) / sizeof(float);
+                const int64_t k_beta_offset = (seq_idx * src7->nb[3] + head_idx * src7->nb[2]) / sizeof(float);
+                const int64_t attn_offset = (seq_idx * src8->nb[3] + head_idx * src8->nb[2]) / sizeof(float);
+                const int64_t decay_mask_offset = (seq_idx * src5->nb[3] + head_idx * src5->nb[2]) / sizeof(float);
+
+                // Calculate strides for each tensor
+                const int64_t q_stride0 = src0->nb[0] / sizeof(float); // element stride (1 for contiguous f32)
+                const int64_t q_stride1 = src0->nb[1] / sizeof(float); // token stride (S_k for contiguous f32)
+                const int64_t k_stride0 = src1->nb[0] / sizeof(float); // element stride
+                const int64_t k_stride1 = src1->nb[1] / sizeof(float); // token stride (S_k)
+                const int64_t v_stride0 = src2->nb[0] / sizeof(float); // element stride
+                const int64_t v_stride1 = src2->nb[1] / sizeof(float); // token stride (S_v)
+                const int64_t g_stride0 = src3->nb[0] / sizeof(float); // element stride
+                const int64_t v_beta_stride0 = src6->nb[0] / sizeof(float); // element stride
+                const int64_t v_beta_stride1 = src6->nb[1] / sizeof(float); // token stride (S_v)
+                const int64_t k_beta_stride0 = src7->nb[0] / sizeof(float); // element stride
+                const int64_t k_beta_stride1 = src7->nb[1] / sizeof(float); // token stride (S_k)
+                const int64_t attn_stride0 = src8->nb[0] / sizeof(float); // element stride
+                const int64_t attn_stride1 = src8->nb[1] / sizeof(float); // row stride (chunk_size)
+                const int64_t decay_mask_stride0 = src5->nb[0] / sizeof(float); // element stride
+                const int64_t decay_mask_stride1 = src5->nb[1] / sizeof(float); // row stride (chunk_size)
+
+                // Get decay mask for this chunk and head
+                float * decay_mask = (float *) src5->data + decay_mask_offset;
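+                // (Indexed below as decay_mask[i * chunk_size + j], i.e. rows are assumed
+                // contiguous; decay_mask_stride0/1 above are not used for this access.)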
+
+                // Use pre-computed attention matrix from src8 (after triangular updates)
+                // The Python reference computes triangular updates before the chunk loop
+                float * attn_precomputed = (float *) src8->data + attn_offset;
+
+                // Debug: print precomputed attention matrix values
+                float attn_precomputed_sum = 0.0f;
+                float attn_precomputed_max = 0.0f;
+                for (int64_t i = 0; i < chunk_size * chunk_size; i++) {
+                    attn_precomputed_sum += attn_precomputed[i];
+                    attn_precomputed_max = fmaxf(attn_precomputed_max, fabsf(attn_precomputed[i]));
+                }
+                printf("C++ attn_precomputed_sum = %f, max = %f\n", attn_precomputed_sum, attn_precomputed_max);
+                printf("C++ attn_precomputed first 10 values: %f, %f, %f, %f, %f, %f, %f, %f, %f, %f\n",
+                       attn_precomputed[0], attn_precomputed[1], attn_precomputed[2], attn_precomputed[3], attn_precomputed[4],
+                       attn_precomputed[5], attn_precomputed[6], attn_precomputed[7], attn_precomputed[8], attn_precomputed[9]);
+                // (diagonal indices below assume chunk_size == 64)
+                printf("C++ attn_precomputed diagonal values: %f, %f, %f, %f, %f\n",
+                       attn_precomputed[0], attn_precomputed[65], attn_precomputed[130], attn_precomputed[195], attn_precomputed[260]);
+
+                // Get g values for this chunk and head
+                float * g_vals = (float *) src3->data + g_offset;
+
+                // Get v_beta and k_beta for this chunk and head
+                float * v_beta_ptr = (float *) src6->data + v_beta_offset;
+                float * k_beta_ptr = (float *) src7->data + k_beta_offset;
+
+                // Debug: print v_beta and k_beta values
+                float v_beta_sum = 0.0f;
+                for (int64_t i = 0; i < chunk_size * S_v; i++) {
+                    v_beta_sum += v_beta_ptr[i];
+                }
+                printf("C++ v_beta_sum = %f\n", v_beta_sum);
+
+                float k_beta_sum = 0.0f;
+                for (int64_t i = 0; i < chunk_size * S_k; i++) {
+                    k_beta_sum += k_beta_ptr[i];
+                }
+                printf("C++ k_beta_sum = %f\n", k_beta_sum);
+
+                // Compute value = attn_precomputed @ v_beta
+                float * value = (float *) malloc(chunk_size * S_v * sizeof(float));
+                for (int64_t i = 0; i < chunk_size; i++) {
+                    for (int64_t d = 0; d < S_v; d++) {
                         float sum = 0.0f;
-                        for (int64_t k = 0; k < i; k++) {
-                            // This implements: sum over k of (original_row[k] * sub[k, j])
-                            // Where sub[k, j] is attn[k, j] (the original value before updates)
-                            sum += original_row[k] * attn[k * tokens_in_chunk + j]; // Use tokens_in_chunk for indexing
+                        for (int64_t j = 0; j < chunk_size; j++) {
+                            sum += attn_precomputed[i * chunk_size + j] * v_beta_ptr[d * v_beta_stride0 + j * v_beta_stride1];
                         }
-                        // The new value is: original_value + matrix_mult_result
-                        attn[i * tokens_in_chunk + j] = original_row[j] + sum;
+                        value[i * S_v + d] = sum;
                     }
                 }
-
-                // Step 3: Add identity matrix (equivalent to Python's: attn = attn + torch.eye(...))
-                for (int64_t i = 0; i < tokens_in_chunk; i++) {
-                    attn[i * tokens_in_chunk + i] += 1.0f;
-                }
-
-                // Step 4: Compute value = attn @ v_beta
-                ggml_compute_value_f32(attn.data(), v_beta_chunk.data(), value.data(), tokens_in_chunk,
-                                       S_v); // Use actual tokens, not full chunk size
-
-                // Step 5: Compute k_cumdecay = attn @ (k_beta * g_exp)
-                ggml_compute_k_cumdecay_f32(attn.data(), k_beta_chunk.data(), g_exp.data(), k_cumdecay.data(),
-                                            tokens_in_chunk, S_k); // Use actual tokens, not full chunk size
-
-                // Step 6: Compute core attention output for this chunk
-                // First, compute v_new for all tokens in the chunk
-                std::vector<float> v_new_chunk(tokens_in_chunk * S_v);
-                for (int64_t i = 0; i < tokens_in_chunk; i++) {
-                    // v_prime = k_cumdecay @ last_state
-                    // k_cumdecay[i] is [S_k], last_state for head h is [S_k, S_v]
-                    std::vector<float> v_prime(S_v, 0.0f);
-                    for (int64_t d1 = 0; d1 < S_v; d1++) {
-                        for (int64_t d2 = 0; d2 < S_k; d2++) {
-                            // State index: [d2, d1 + h*S_v] in GGML format
-                            int64_t state_idx = d2 * (S_v * H_v) + (d1 + h * S_v);
-                            v_prime[d1] += k_cumdecay[i * S_k + d2] * last_state[state_idx];
+
+                float value_sum = 0.0f;
+                for (int64_t i = 0; i < chunk_size * S_v; i++) {
+                    value_sum += value[i];
+                }
+                printf("C++ value_sum = %f (head %" PRId64 ", seq %" PRId64 ")\n", value_sum, head_idx, seq_idx);
+
+                // Compute k_cumdecay = attn_precomputed @ (k_beta * g.exp().unsqueeze(-1))
+                float * k_cumdecay = (float *) malloc(chunk_size * S_k * sizeof(float));
+                for (int64_t i = 0; i < chunk_size; i++) {
+                    for (int64_t d = 0; d < S_k; d++) {
+                        float sum = 0.0f;
+                        for (int64_t j = 0; j < chunk_size; j++) {
+                            float g_exp = expf(g_vals[j * g_stride0]);
+                            sum += attn_precomputed[i * chunk_size + j] * k_beta_ptr[d * k_beta_stride0 + j * k_beta_stride1] * g_exp;
                         }
-                    }
-
-                    // v_new = v_i - v_prime
-                    for (int64_t d = 0; d < S_v; d++) {
-                        v_new_chunk[i * S_v + d] = value[i * S_v + d] - v_prime[d];
+                        k_cumdecay[i * S_k + d] = sum;
                     }
                 }
-
-                // Now process each token in the chunk to compute output
-                for (int64_t i = 0; i < tokens_in_chunk; i++) {
-                    // q_i @ k_i.T * decay_mask
-                    std::vector<float> q_k_attn(chunk_size);
+
+                float k_cumdecay_sum = 0.0f;
+                for (int64_t i = 0; i < chunk_size * S_k; i++) {
+                    k_cumdecay_sum += k_cumdecay[i];
+                }
+                printf("C++ k_cumdecay_sum = %f (head %" PRId64 ", seq %" PRId64 ")\n", k_cumdecay_sum, head_idx, seq_idx);
+
+                // Compute fresh attention matrix for this chunk, just like Python reference line 118
+                // attn = (q_i @ k_i.transpose(-1, -2) * decay_mask[:, :, i]).masked_fill_(mask, 0)
+                float * attn = (float *) malloc(chunk_size * chunk_size * sizeof(float));
+
+                // First compute q_i @ k_i.transpose(-1, -2)
+                for (int64_t i = 0; i < chunk_size; i++) {
                     for (int64_t j = 0; j < chunk_size; j++) {
                         float sum = 0.0f;
                         for (int64_t d = 0; d < S_k; d++) {
-                            sum += q_chunk[i * S_k + d] * k_chunk[j * S_k + d];
-                        }
-                        // Apply decay mask - use the precomputed decay mask from src5 tensor
-                        if (j <= i) { // Only apply to lower triangular part (i >= j)
-                            float decay_val = ggml_get_f32_nd(
-                                src5, i, j, h, seq); // [i, j, h, seq] to get exp(g_cumsum[i] - g_cumsum[j]) for head h
-                            q_k_attn[j] = sum * decay_val;
-                        } else {
-                            q_k_attn[j] = 0.0f; // Zero out upper triangular part (i < j)
+                            float q_val = ((float *)src0->data)[q_offset + d * q_stride0 + i * q_stride1];
+                            float k_val = ((float *)src1->data)[k_offset + d * k_stride0 + j * k_stride1];
+                            sum += q_val * k_val;
                         }
+                        attn[i * chunk_size + j] = sum * decay_mask[i * chunk_size + j];
                     }
                 }
-
-                    // attn_inter = q_i * g_exp @ last_state
-                    // q_chunk[i] is [S_k], g_exp[i] is scalar, last_state for head h is [S_k, S_v]
-                    std::vector<float> attn_inter(S_v, 0.0f);
-                    for (int64_t d1 = 0; d1 < S_v; d1++) {
-                        for (int64_t d2 = 0; d2 < S_k; d2++) {
-                            // State index: [d2, d1 + h*S_v] in GGML format
-                            int64_t state_idx = d2 * (S_v * H_v) + (d1 + h * S_v);
-                            // Use double precision for the computation and then cast to float
-                            double temp_result =
-                                (double) q_chunk[i * S_k + d2] * g_exp[i] * (double) last_state[state_idx];
-                            attn_inter[d1] += (float) temp_result;
-                        }
+                    }
+
+                // Apply upper triangular mask (masked_fill_(mask, 0))
+                for (int64_t i = 0; i < chunk_size; i++) {
+                    for (int64_t j = i + 1; j < chunk_size; j++) {
+                        attn[i * chunk_size + j] = 0.0f;
                     }
-
-                    // core_attn_out = attn_inter + attn @ v_new
-                    // We need to use the attention matrix computed for this position (i)
-                    // The attn matrix was computed earlier in the chunk processing
-                    // attn @ v_new where attn is [chunk_size, chunk_size] and v_new is [chunk_size, S_v]
-                    // For token i, we want sum_j(attn[i, j] * v_new[j, :])
-                    std::vector<float> attn_v_new(S_v, 0.0f);
+                }
+
+                // Compute v_prime = k_cumdecay @ last_recurrent_state
+                float * v_prime = (float *) malloc(chunk_size * S_v * sizeof(float));
+                for (int64_t i = 0; i < chunk_size; i++) {
                     for (int64_t d = 0; d < S_v; d++) {
-                        for (int64_t j = 0; j < tokens_in_chunk; j++) { // Only process actual tokens
-                            // Use the attention matrix that was computed for position i
-                            // attn[i * chunk_size + j] is the attention from position i to j
-                            // v_new_chunk[j * S_v + d] is the v_new value for token j, dimension d
-                            attn_v_new[d] += attn[i * chunk_size + j] * v_new_chunk[j * S_v + d];
-                        }
-                    }
-
-                    // Store output - only store for the original sequence length (not the padded part)
-                    int64_t global_pos =
-                        chunk_idx * chunk_size + i; // Convert local chunk position to global sequence position
-                    if (global_pos < n_tokens) { // Make sure we don't exceed the original sequence length
-                        for (int64_t d = 0; d < S_v; d++) {
-                            // Output tensor is [S_v * H_v * n_tokens] for single sequence (n_seqs=1)
-                            // Indexing: [dim_idx + head_idx*S_v + pos_idx*S_v*H_v]
-                            int64_t ggml_idx = d + h * S_v + global_pos * S_v * H_v;
-                            output[ggml_idx] = attn_inter[d] + attn_v_new[d];
+                        float sum = 0.0f;
+                        for (int64_t k = 0; k < S_k; k++) {
+                            sum += k_cumdecay[i * S_k + k] * last_recurrent_state[k * S_v + d];
                         }
+                        v_prime[i * S_v + d] = sum;
                     }
                 }
-
-                // Step 7: Update last_recurrent_state
-                std::vector<float> new_state_vec(S_v * S_v * H_v);
-
-                // Update running cumulative sum with current chunk's values
-                float prev_cumsum = 0.0f; // Cumulative sum from all previous chunks
-                if (chunk_idx > 0) {
-                    // Get the cumulative sum of the last token from the previous chunk
-                    int64_t prev_chunk_last_token = std::min(chunk_size, n_tokens - (chunk_idx - 1) * chunk_size) - 1;
-                    if (prev_chunk_last_token >= 0) {
-                        prev_cumsum = running_cumsum[(chunk_idx - 1) * chunk_size + prev_chunk_last_token];
+
+                // Debug prints for key intermediate values
+                float attn_sum = 0.0f;
+                for (int64_t i = 0; i < chunk_size * chunk_size; i++) {
+                    attn_sum += attn[i];
+                }
+                printf("C++ attn_sum = %f\n", attn_sum);
+
+                // Debug: print first few values of attn matrix
+                printf("C++ attn first 5 values: %f, %f, %f, %f, %f\n",
+                       attn[0], attn[1], attn[2], attn[3], attn[4]);
+
+                float v_prime_sum = 0.0f;
+                for (int64_t i = 0; i < chunk_size * S_v; i++) {
+                    v_prime_sum += v_prime[i];
+                }
+                printf("C++ v_prime_sum = %f\n", v_prime_sum);
+
+                // Compute v_new = v_i - v_prime
+                float * v_new = (float *) malloc(chunk_size * S_v * sizeof(float));
+                for (int64_t i = 0; i < chunk_size; i++) {
+                    for (int64_t d = 0; d < S_v; d++) {
+                        v_new[i * S_v + d] = ((float *)src2->data)[v_offset + d * v_stride0 + i * v_stride1] - v_prime[i * S_v + d];
                     }
                 }
-
-                // Update running_cumsum for tokens in this chunk
-                for (int64_t t = 0; t < tokens_in_chunk; t++) {
-                    int64_t global_pos = chunk_idx * chunk_size + t;
-                    if (global_pos < n_tokens) {
-                        running_cumsum[global_pos] = prev_cumsum + g_cumsum[t];
-                    }
+
+                float v_new_sum = 0.0f;
+                for (int64_t i = 0; i < chunk_size * S_v; i++) {
+                    v_new_sum += v_new[i];
                 }
-
-                // Find the last token position in the current chunk (not the entire sequence)
-                int64_t last_pos_in_chunk =
-                    std::min((chunk_idx + 1) * chunk_size, n_tokens) - 1; // Last actual token in this chunk
-                if (last_pos_in_chunk >= chunk_idx * chunk_size && last_pos_in_chunk < n_tokens) {
-                    float g_last =
-                        running_cumsum[last_pos_in_chunk]; // Use the last token's cumulative sum in this chunk
-                    // Use double precision for exponential to avoid overflow/underflow
-                    double g_last_exp_double = exp((double) g_last);
-                    float g_last_exp = (float) g_last_exp_double;
-
-                    // last_state * g_exp[last]
-                    for (int64_t i = 0; i < S_k; i++) {
-                        for (int64_t j = 0; j < S_v; j++) {
-                            // State index: [i, j + h*S_v] in GGML format
-                            int64_t state_idx = i * (S_v * H_v) + (j + h * S_v);
-                            new_state_vec[i * (S_v * H_v) + (j + h * S_v)] = last_state[state_idx] * g_last_exp;
+                printf("C++ v_new_sum = %f\n", v_new_sum);
+
+                // Compute attn_inter = (q_i * g[:, :, i, :, None].exp()) @ last_recurrent_state
+                float * attn_inter = (float *) malloc(chunk_size * S_v * sizeof(float));
+                for (int64_t i = 0; i < chunk_size; i++) {
+                    for (int64_t d = 0; d < S_v; d++) {
+                        float sum = 0.0f;
+                        float g_exp = expf(g_vals[i * g_stride0]);
+                        for (int64_t k = 0; k < S_k; k++) {
+                            sum += ((float *)src0->data)[q_offset + k * q_stride0 + i * q_stride1] * g_exp * last_recurrent_state[k * S_v + d];
                         }
+                        attn_inter[i * S_v + d] = sum;
                     }
-
-                    // Add (k_i * (g_last - g_i).exp()).T @ v_new
-                    // This should be: (k_chunk * g_diff_exp).T @ v_new_chunk
-                    // where k_chunk is [chunk_size, S_k], v_new_chunk is [chunk_size, S_v]
-                    // result is [S_k, S_v]
-
-                    // First compute v_new for all positions in the chunk
-                    std::vector<float> v_new_chunk(chunk_size * S_v);
-                    for (int64_t i = 0; i < tokens_in_chunk; i++) { // Only process actual tokens, not full chunk size
-                        for (int64_t d1 = 0; d1 < S_v; d1++) {
-                            // Recompute v_prime for this position
-                            float v_prime = 0.0f;
-                            for (int64_t d2 = 0; d2 < S_k; d2++) {
-                                // State index: [d2, d1 + h*S_v] in GGML format
-                                int64_t state_idx = d2 * (S_v * H_v) + (d1 + h * S_v);
-                                float k_val = k_cumdecay[i * S_k + d2];
-                                float s_val = last_state[state_idx];
-                                v_prime += k_val * s_val;
-                            }
-                            v_new_chunk[i * S_v + d1] = value[i * S_v + d1] - v_prime;
+                }
+
+                float attn_inter_sum = 0.0f;
+                for (int64_t i = 0; i < chunk_size * S_v; i++) {
+                    attn_inter_sum += attn_inter[i];
+                }
+                printf("C++ attn_inter_sum = %f\n", attn_inter_sum);
+
+                // Compute core_attn_out = attn_inter + attn @ v_new
+                // Output tensor layout: [S_v * H_v, n_tokens, 1, 1]
+                const int64_t out_offset = head_idx * (S_v * n_tokens) + chunk_idx * (S_v * chunk_size);
+                float * core_attn_out = output + out_offset;
+
+                for (int64_t i = 0; i < chunk_size; i++) {
+                    for (int64_t d = 0; d < S_v; d++) {
+                        float sum = 0.0f;
+                        for (int64_t j = 0; j < chunk_size; j++) {
+                            sum += attn[i * chunk_size + j] * v_new[j * S_v + d];
                         }
+                        core_attn_out[i * S_v + d] = attn_inter[i * S_v + d] + sum;
                     }
-
-                    // Now compute (k_chunk * g_diff_exp).T @ v_new_chunk
-                    // This is a matrix multiplication: [S_k, chunk_size] @ [chunk_size, S_v] = [S_k, S_v]
-                    // Only process the original sequence length, not the padded chunk size
-                    // In the Python reference, this is: (k_i * g_diff_exp).transpose(-1, -2) @ v_new
-                    // where g_diff_exp = torch.exp(g_last - g) and g_last = g[-1] (last token in chunk)
-                    for (int64_t d1 = 0; d1 < S_k; d1++) {
-                        for (int64_t d2 = 0; d2 < S_v; d2++) {
-                            float sum = 0.0f;
-                            for (int64_t i = 0; i < tokens_in_chunk; i++) { // Only process actual tokens
-                                // Get g values for the current chunk from the cumsum tensor (src3)
-                                // For state update: g_last (last token in chunk) - g_current (current token)
-                                // g tensor has shape [GGML_DELTA_NET_CHUNK, 1, H_v, n_seqs] in GGML format after cumsum and reshaping
-
-                                // Access g_cumsum for current position in chunk - need to access the original g tensor before cumsum
-                                // The g_cumsum tensor is src3, but we need the original g values for the diff computation
-                                // Actually, we need to access g values that were cumsummed to compute the diff
-
-                                // Get the original g_cumsum values for current and last token in the chunk
-                                // g_cumsum values are stored in src3, which was reshaped from [chunk_size, 1, H_v, n_seqs] to [chunk_size, 1, H_v, n_seqs]
-                                float g_current = g_cumsum[i]; // Use the g_cumsum computed earlier in this chunk
-                                float g_last =
-                                    g_cumsum[tokens_in_chunk - 1]; // Use the last token's cumsum in this chunk
-
-                                float g_diff = g_last - g_current;
-                                float g_diff_exp;
-                                // Use double precision for exponential to avoid overflow/underflow
-                                // For numerical stability, if g_diff is very negative, exp(g_diff) will be very small
-                                if (g_diff < -50.0f) {
-                                    g_diff_exp = 0.0f; // Set to zero to avoid underflow
-                                } else {
-                                    double g_diff_exp_double = exp((double) g_diff);
-                                    g_diff_exp = (float) g_diff_exp_double;
-                                }
-                                sum += k_chunk[i * S_k + d1] * g_diff_exp * v_new_chunk[i * S_v + d2];
-                            }
-                            // State index: [d1, d2 + h*S_v] in GGML format
-                            int64_t state_idx = d1 * (S_v * H_v) + (d2 + h * S_v);
-                            new_state_vec[state_idx] += sum;
-                        }
+                }
+
+                float core_attn_out_sum = 0.0f;
+                for (int64_t i = 0; i < chunk_size * S_v; i++) {
+                    core_attn_out_sum += core_attn_out[i];
+                }
+                printf("C++ core_attn_out_sum = %f\n", core_attn_out_sum);
+
+                // Update last_recurrent_state
+                // last_recurrent_state = (
+                //     last_recurrent_state * g[:, :, i, -1, None, None].exp()
+                //     + (k_i * (g[:, :, i, -1, None] - g[:, :, i]).exp()[..., None]).transpose(-1, -2) @ v_new
+                // )
+
+                float g_last = g_vals[(chunk_size - 1) * g_stride0];
+                float g_last_exp = expf(g_last);
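+                // (g_last is a cumulative log-decay; plain expf assumes it stays within
+                // float range, unlike the removed path which used double precision.)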
+
+                // First part: last_recurrent_state * g_last_exp
+                for (int64_t k = 0; k < S_k; k++) {
+                    for (int64_t v = 0; v < S_v; v++) {
+                        last_recurrent_state[k * S_v + v] *= g_last_exp;
                     }
-
-                    // Update last_state
-                    for (int64_t i = 0; i < S_k; i++) {
-                        for (int64_t j = 0; j < S_v; j++) {
-                            // State index: [i, j + h*S_v] in GGML format
-                            int64_t state_idx = i * (S_v * H_v) + (j + h * S_v);
-                            last_state[state_idx] = new_state_vec[state_idx];
+                }
+
+                // Second part: (k_i * (g_last - g).exp()).transpose(-1, -2) @ v_new
+                float * k_gated = (float *) malloc(chunk_size * S_k * sizeof(float));
+                for (int64_t i = 0; i < chunk_size; i++) {
+                    float g_diff_exp = expf(g_last - g_vals[i * g_stride0]);
+                    for (int64_t d = 0; d < S_k; d++) {
+                        k_gated[i * S_k + d] = ((float *)src1->data)[k_offset + d * k_stride0 + i * k_stride1] * g_diff_exp;
+                    }
+                }
+
+                // Compute k_gated.T @ v_new
+                for (int64_t k = 0; k < S_k; k++) {
+                    for (int64_t v = 0; v < S_v; v++) {
+                        float sum = 0.0f;
+                        for (int64_t i = 0; i < chunk_size; i++) {
+                            sum += k_gated[i * S_k + k] * v_new[i * S_v + v];
                         }
+                        last_recurrent_state[k * S_v + v] += sum;
                     }
                 }
-            }
-        }
-    }
-    // Copy the final state to the output tensor in the correct GGML layout
-    // GGML expects state layout: [d1, d2 + h*head_dim]
-    for (int64_t h = 0; h < H_v; h++) {
-        for (int64_t d1 = 0; d1 < S_v; d1++) {
-            for (int64_t d2 = 0; d2 < S_v; d2++) {
-                // GGML state index: [d1, d2 + h*head_dim]
-                int64_t ggml_state_idx = d1 * (S_v * H_v) + (d2 + h * S_v);
-                // Our computed state index: [d1, d2 + h*S_v]
-                int64_t computed_state_idx = d1 * (S_v * H_v) + (d2 + h * S_v);
-                float val = last_state[computed_state_idx];
-                new_state[ggml_state_idx] = val;
+
+                // Free temporary memory
+                free(attn);
+                free(value);
+                free(k_cumdecay);
+                free(v_prime);
+                free(v_new);
+                free(attn_inter);
+                free(k_gated);
             }
         }
     }