|
|
@@ -10458,7 +10458,7 @@ void ggml_compute_forward_gla(
|
|
|
}
|
|
|
|
|
|
// Helper function to compute cumulative sum
|
|
|
-static void ggml_cumsum_f32(const float * x, float * dst, const int64_t n) {
|
|
|
+static void delta_cumsum_f32(const float * x, float * dst, const int64_t n) {
|
|
|
float cumsum = 0.0f;
|
|
|
for (int64_t i = 0; i < n; i++) {
|
|
|
cumsum += x[i];
|
|
|
@@ -10467,7 +10467,7 @@ static void ggml_cumsum_f32(const float * x, float * dst, const int64_t n) {
|
|
|
}
|
|
|
|
|
|
// Helper function for matrix multiplication
|
|
|
-static void ggml_matmul_f32(const float * a, const float * b, float * dst,
|
|
|
+static void delta_matmul_f32(const float * a, const float * b, float * dst,
|
|
|
const int64_t m, const int64_t n, const int64_t k) {
|
|
|
for (int64_t i = 0; i < m; i++) {
|
|
|
for (int64_t j = 0; j < n; j++) {
|
|
|
@@ -10481,7 +10481,7 @@ static void ggml_matmul_f32(const float * a, const float * b, float * dst,
|
|
|
}
|
|
|
|
|
|
// Helper function to create upper triangular mask
|
|
|
-static void ggml_create_upper_triangular_mask(bool * mask, const int64_t size) {
|
|
|
+static void delta_create_upper_triangular_mask(bool * mask, const int64_t size) {
|
|
|
for (int64_t i = 0; i < size; i++) {
|
|
|
for (int64_t j = 0; j < size; j++) {
|
|
|
mask[i * size + j] = (j >= i); // upper triangular with diagonal
|
|
|
@@ -10505,7 +10505,7 @@ static void ggml_compute_chunk_decay_mask_f32(const float * g_cumsum, float * de
|
|
|
}
|
|
|
|
|
|
// Helper function to compute k_beta @ key.T
|
|
|
-static void ggml_compute_k_beta_key_t_f32(const float * k_beta, const float * key,
|
|
|
+static void delta_compute_k_beta_key_t_f32(const float * k_beta, const float * key,
|
|
|
float * k_beta_key_t,
|
|
|
const int64_t chunk_size, const int64_t k_head_dim) {
|
|
|
for (int64_t i = 0; i < chunk_size; i++) {
|
|
|
@@ -10522,7 +10522,7 @@ static void ggml_compute_k_beta_key_t_f32(const float * k_beta, const float * ke
|
|
|
}
|
|
|
|
|
|
// Helper function to apply triangular updates
|
|
|
-static void ggml_apply_triangular_updates_f32(float * attn, const int64_t chunk_size) {
|
|
|
+static void delta_apply_triangular_updates_f32(float * attn, const int64_t chunk_size) {
|
|
|
for (int64_t i = 1; i < chunk_size; i++) {
|
|
|
for (int64_t j = 0; j < i; j++) {
|
|
|
float sum = 0.0f;
|
|
|
@@ -10535,14 +10535,14 @@ static void ggml_apply_triangular_updates_f32(float * attn, const int64_t chunk_
|
|
|
}
|
|
|
|
|
|
// Helper function to add identity matrix
|
|
|
-static void ggml_add_identity_matrix_f32(float * matrix, const int64_t size) {
|
|
|
+static void delta_add_identity_matrix_f32(float * matrix, const int64_t size) {
|
|
|
for (int64_t i = 0; i < size; i++) {
|
|
|
matrix[i * size + i] += 1.0f;
|
|
|
}
|
|
|
}
|
|
|
|
|
|
// Helper function to compute value = attn @ v_beta
|
|
|
-static void ggml_compute_value_f32(const float * attn, const float * v_beta,
|
|
|
+static void delta_compute_value_f32(const float * attn, const float * v_beta,
|
|
|
float * value,
|
|
|
const int64_t chunk_size, const int64_t v_head_dim) {
|
|
|
for (int64_t i = 0; i < chunk_size; i++) {
|
|
|
@@ -10558,21 +10558,19 @@ static void ggml_compute_value_f32(const float * attn, const float * v_beta,
|
|
|
}
|
|
|
|
|
|
// Helper function to compute k_cumdecay = attn @ (k_beta * g.exp().unsqueeze(-1))
|
|
|
-static void ggml_compute_k_cumdecay_f32(const float * attn, const float * k_beta, const double * g_exp,
|
|
|
- float * k_cumdecay,
|
|
|
- const int64_t chunk_size, const int64_t k_head_dim) {
|
|
|
+static void delta_compute_k_cumdecay_f32(const float * attn, const float * k_beta, const float * g,
|
|
|
+ float * k_cumdecay, const int64_t chunk_size, const int64_t k_head_dim) {
|
|
|
for (int64_t i = 0; i < chunk_size; i++) {
|
|
|
for (int64_t d = 0; d < k_head_dim; d++) {
|
|
|
float sum = 0.0f;
|
|
|
for (int64_t j = 0; j < chunk_size; j++) {
|
|
|
int64_t k_beta_idx = j * k_head_dim + d;
|
|
|
- sum += attn[i * chunk_size + j] * k_beta[k_beta_idx] * g_exp[j];
|
|
|
+ sum += attn[i * chunk_size + j] * k_beta[k_beta_idx] * expf(g[j]);
|
|
|
}
|
|
|
k_cumdecay[i * k_head_dim + d] = sum;
|
|
|
}
|
|
|
}
|
|
|
}
|
|
|
-// Helper functions for delta net computation
|
|
|
|
|
|
// Matrix multiplication helper for delta net
|
|
|
static void ggml_delta_net_matmul_f32(const float * a, const int64_t rows_a, const int64_t cols_a, const int64_t cols_b,
|
|
|
@@ -10590,47 +10588,76 @@ static void ggml_delta_net_matmul_f32(const float * a, const int64_t rows_a, con
|
|
|
}
|
|
|
}
|
|
|
|
|
|
-// Helper function to apply triangular updates to attention matrix
|
|
|
-static void delta_net_compute_diagonal_updates(
|
|
|
- float * attn,
|
|
|
- const int64_t chunk_size,
|
|
|
- const int64_t n_heads,
|
|
|
- const int64_t n_seqs) {
|
|
|
-
|
|
|
- // Apply triangular updates like in the Python reference:
|
|
|
- // for i in range(1, chunk_size):
|
|
|
- // row = attn[..., i, :i].clone()
|
|
|
- // sub = attn[..., :i, :i].clone()
|
|
|
- // attn[..., i, :i] = row + (row.unsqueeze(-1) * sub).sum(-2)
|
|
|
- // attn = attn + torch.eye(chunk_size, dtype=attn.dtype, device=attn.device)
|
|
|
-
|
|
|
- for (int64_t head = 0; head < n_heads; head++) {
|
|
|
- for (int64_t seq = 0; seq < n_seqs; seq++) {
|
|
|
- // Get pointer to this head's attention matrix
|
|
|
- float * attn_head = attn + (head * chunk_size * chunk_size * n_seqs) + (seq * chunk_size * chunk_size);
|
|
|
-
|
|
|
- // Create temporary storage for the original values to avoid in-place modification
|
|
|
- float * attn_copy = (float *) malloc(chunk_size * chunk_size * sizeof(float));
|
|
|
- memcpy(attn_copy, attn_head, chunk_size * chunk_size * sizeof(float));
|
|
|
-
|
|
|
- // Apply triangular updates using the original values
|
|
|
- for (int64_t i = 1; i < chunk_size; i++) {
|
|
|
- for (int64_t j = 0; j < i; j++) {
|
|
|
- float sum = 0.0f;
|
|
|
- // Compute (row.unsqueeze(-1) * sub).sum(-2) using original values
|
|
|
- for (int64_t k = 0; k < i; k++) {
|
|
|
- sum += attn_copy[i * chunk_size + k] * attn_copy[k * chunk_size + j];
|
|
|
- }
|
|
|
- attn_head[i * chunk_size + j] = attn_copy[i * chunk_size + j] + sum;
|
|
|
- }
|
|
|
+// Helper function to compute q_i @ k_i.transpose(-1, -2) * decay_mask and apply mask
|
|
|
+static void delta_compute_q_k_attn_f32(const float * q, const float * k, const float * decay_mask,
|
|
|
+ float * attn, const bool * mask,
|
|
|
+ const int64_t chunk_size, const int64_t head_dim) {
|
|
|
+ // Compute q @ k.transpose(-1, -2)
|
|
|
+ for (int64_t i = 0; i < chunk_size; i++) {
|
|
|
+ for (int64_t j = 0; j < chunk_size; j++) {
|
|
|
+ float sum = 0.0f;
|
|
|
+ for (int64_t d = 0; d < head_dim; d++) {
|
|
|
+ int64_t q_idx = i * head_dim + d;
|
|
|
+ int64_t k_idx = j * head_dim + d;
|
|
|
+ sum += q[q_idx] * k[k_idx];
|
|
|
+ }
|
|
|
+ // Apply decay mask and causal mask
|
|
|
+ int64_t attn_idx = i * chunk_size + j;
|
|
|
+ attn[attn_idx] = (mask[attn_idx] ? 0.0f : sum * decay_mask[attn_idx]);
|
|
|
+ }
|
|
|
+ }
|
|
|
+}
|
|
|
+
|
|
|
+// Helper function for matrix multiplication with state tensors
|
|
|
+static void delta_matmul_state_f32(const float * a, const float * state, float * dst,
|
|
|
+ const int64_t rows_a, const int64_t cols_a, const int64_t cols_state) {
|
|
|
+ for (int64_t i = 0; i < rows_a; i++) {
|
|
|
+ for (int64_t j = 0; j < cols_state; j++) {
|
|
|
+ float sum = 0.0f;
|
|
|
+ for (int64_t k = 0; k < cols_a; k++) {
|
|
|
+ int64_t a_idx = i * cols_a + k;
|
|
|
+ int64_t state_idx = k * cols_state + j;
|
|
|
+ sum += a[a_idx] * state[state_idx];
|
|
|
}
|
|
|
+ dst[i * cols_state + j] = sum;
|
|
|
+ }
|
|
|
+ }
|
|
|
+}
|
|
|
+
|
|
|
+// Helper function for element-wise tensor subtraction
|
|
|
+static void delta_tensor_subtract_f32(const float * a, const float * b, float * dst, const int64_t size) {
|
|
|
+ for (int64_t i = 0; i < size; i++) {
|
|
|
+ dst[i] = a[i] - b[i];
|
|
|
+ }
|
|
|
+}
|
|
|
+
|
|
|
+// Helper function for element-wise tensor addition
|
|
|
+static void delta_tensor_add_f32(const float * a, const float * b, float * dst, const int64_t size) {
|
|
|
+ for (int64_t i = 0; i < size; i++) {
|
|
|
+ dst[i] = a[i] + b[i];
|
|
|
+ }
|
|
|
+}
|
|
|
+
|
|
|
+// Helper function to update recurrent state
|
|
|
+static void delta_update_recurrent_state_f32(const float * last_state, const float * g_last,
|
|
|
+ const float * k_i, const float * g_diff_exp, const float * v_new, float * new_state,
|
|
|
+ const int64_t chunk_size, const int64_t k_head_dim, const int64_t v_head_dim) {
|
|
|
+ for (int64_t i = 0; i < k_head_dim; i++) {
|
|
|
+ for (int64_t j = 0; j < v_head_dim; j++) {
|
|
|
+ int64_t state_idx = i * v_head_dim + j;
|
|
|
|
|
|
- // Add identity matrix
|
|
|
- for (int64_t i = 0; i < chunk_size; i++) {
|
|
|
- attn_head[i * chunk_size + i] += 1.0f;
|
|
|
+ // last_recurrent_state * g_last
|
|
|
+ float term1 = last_state[state_idx] * (*g_last);
|
|
|
+
|
|
|
+ // (k_i * g_diff_exp).transpose(-1, -2) @ v_new
|
|
|
+ float term2 = 0.0f;
|
|
|
+ for (int64_t k = 0; k < chunk_size; k++) {
|
|
|
+ int64_t k_idx = k * k_head_dim + i;
|
|
|
+ int64_t v_idx = k * v_head_dim + j;
|
|
|
+ term2 += k_i[k_idx] * g_diff_exp[k] * v_new[v_idx];
|
|
|
}
|
|
|
|
|
|
- free(attn_copy);
|
|
|
+ new_state[state_idx] = term1 + term2;
|
|
|
}
|
|
|
}
|
|
|
}
|
|
|
@@ -10650,10 +10677,10 @@ void ggml_compute_forward_delta_net_f32(const ggml_compute_params * params, ggml
|
|
|
const int64_t S_k = (int64_t) dst->op_params[1];
|
|
|
const int64_t S_v = (int64_t) dst->op_params[2];
|
|
|
const int64_t original_n_tokens = (int64_t) dst->op_params[3]; // Get original sequence length
|
|
|
- const int64_t H_k = H_v;
|
|
|
const int64_t n_tokens = original_n_tokens; // Use the original sequence length
|
|
|
const int64_t n_seqs = src0->ne[3]; // q tensor has n_seqs in dim 3
|
|
|
|
|
|
+
|
|
|
// Add assertions to verify tensor dimensions
|
|
|
GGML_ASSERT(src0->ne[3] == n_seqs); // q tensor
|
|
|
GGML_ASSERT(src1->ne[3] == n_seqs); // k tensor
|
|
|
@@ -10669,14 +10696,13 @@ void ggml_compute_forward_delta_net_f32(const ggml_compute_params * params, ggml
|
|
|
const int ith = params->ith;
|
|
|
// const int nth = params->nth; // nth is unused
|
|
|
|
|
|
- // For chunked implementation, we process all sequences in thread 0 for simplicity
|
|
|
- // This can be optimized later to parallelize across sequences
|
|
|
+ // TODO: parallelize across heads and sequences
|
|
|
if (ith != 0) {
|
|
|
return;
|
|
|
}
|
|
|
|
|
|
// Clear output and new state section
|
|
|
- memset(output, 0, ((S_v * H_v * n_tokens) + (S_v * H_v * S_v * n_seqs)) * sizeof(float));
|
|
|
+ memset(output, 0, ((S_v * H_v * n_tokens * n_seqs) + (S_v * S_v * H_v * n_seqs)) * sizeof(float));
|
|
|
|
|
|
// Calculate chunk size
|
|
|
const int64_t chunk_size = GGML_DELTA_NET_CHUNK;
|
|
|
@@ -10684,518 +10710,164 @@ void ggml_compute_forward_delta_net_f32(const ggml_compute_params * params, ggml
|
|
|
const int64_t num_chunks = (n_tokens + pad_size) / chunk_size;
|
|
|
|
|
|
// Apply triangular updates to the precomputed attention matrix
|
|
|
- // This is the missing piece that was causing the attention matrix to have all zeros
|
|
|
float * attn_data = (float *) src8->data;
|
|
|
- delta_net_compute_diagonal_updates(attn_data, chunk_size, H_v, n_seqs);
|
|
|
-
|
|
|
- // Debug: Check attention matrix after triangular updates
|
|
|
- float attn_after_updates_sum = 0.0f;
|
|
|
- float attn_after_updates_max = 0.0f;
|
|
|
- for (int64_t i = 0; i < chunk_size * chunk_size * H_v * n_seqs; i++) {
|
|
|
- attn_after_updates_sum += attn_data[i];
|
|
|
- attn_after_updates_max = fmaxf(attn_after_updates_max, fabsf(attn_data[i]));
|
|
|
- }
|
|
|
-
|
|
|
- for (int64_t seq_idx = 0; seq_idx < n_seqs; seq_idx++) {
|
|
|
- for (int64_t head_idx = 0; head_idx < H_v; head_idx++) {
|
|
|
- const int64_t v_beta_offset = (head_idx * src6->nb[2] + seq_idx * src6->nb[3]) / sizeof(float);
|
|
|
- const int64_t k_beta_offset = (head_idx * src7->nb[2] + seq_idx * src7->nb[3]) / sizeof(float);
|
|
|
- const int64_t attn_offset = (head_idx * src8->nb[2] + seq_idx * src8->nb[3]) / sizeof(float);
|
|
|
- const int64_t g_offset = (head_idx * src3->nb[2] + seq_idx * src3->nb[3]) / sizeof(float);
|
|
|
-
|
|
|
- // Fixed memory access patterns with bounds checking
|
|
|
- float * attn_precomputed = (float *) src8->data + attn_offset;
|
|
|
- float * v_beta_ptr = (float *) src6->data + v_beta_offset;
|
|
|
- float * k_beta_ptr = (float *) src7->data + k_beta_offset;
|
|
|
- float * g_vals = (float *) src3->data + g_offset;
|
|
|
-
|
|
|
- // Add bounds checking to prevent out-of-bounds access
|
|
|
- const int64_t attn_total_elements = src8->ne[0] * src8->ne[1] * src8->ne[2] * src8->ne[3];
|
|
|
- const int64_t v_beta_total_elements = src6->ne[0] * src6->ne[1] * src6->ne[2] * src6->ne[3];
|
|
|
- const int64_t k_beta_total_elements = src7->ne[0] * src7->ne[1] * src7->ne[2] * src7->ne[3];
|
|
|
- const int64_t g_total_elements = src3->ne[0] * src3->ne[1] * src3->ne[2] * src3->ne[3];
|
|
|
-
|
|
|
- // Compute value = attn @ v_beta with deterministic tensor access
|
|
|
- // printf("C++ DEBUG: Computing value = attn @ v_beta with deterministic tensor access\n");
|
|
|
- float * value = (float *) malloc(chunk_size * S_v * sizeof(float));
|
|
|
-
|
|
|
- // Calculate tensor strides for deterministic access
|
|
|
- const int64_t attn_stride0 = src8->nb[0] / sizeof(float); // chunk_size dimension
|
|
|
- const int64_t v_beta_stride0 = src6->nb[0] / sizeof(float); // S_v dimension
|
|
|
- const int64_t v_beta_stride1 = src6->nb[1] / sizeof(float); // chunk_size dimension
|
|
|
-
|
|
|
- // printf("C++ DEBUG: Tensor strides for deterministic access:\n");
|
|
|
- // printf(" attn_stride0=%ld, v_beta_stride0=%ld, v_beta_stride1=%ld\n",
|
|
|
- // attn_stride0, v_beta_stride0, v_beta_stride1);
|
|
|
-
|
|
|
- for (int64_t i = 0; i < chunk_size; i++) {
|
|
|
- for (int64_t d = 0; d < S_v; d++) {
|
|
|
- float sum = 0.0f;
|
|
|
- for (int64_t j = 0; j < chunk_size; j++) {
|
|
|
- // Deterministic tensor access using stride calculations
|
|
|
- // attn[i][j] access: i * attn_stride0 + j
|
|
|
- int64_t attn_idx = i * attn_stride0 + j;
|
|
|
- if (attn_idx >= chunk_size * chunk_size) {
|
|
|
- // printf("ERROR: attn access out of bounds: attn_idx=%ld, max=%ld\n",
|
|
|
- // attn_idx, chunk_size * chunk_size);
|
|
|
- continue;
|
|
|
- }
|
|
|
-
|
|
|
- // v_beta[j][d] access: j * v_beta_stride1 + d
|
|
|
- int64_t v_beta_idx = j * v_beta_stride1 + d;
|
|
|
- if (v_beta_idx >= chunk_size * S_v) {
|
|
|
- // printf("ERROR: v_beta access out of bounds: v_beta_idx=%ld, max=%ld\n",
|
|
|
- // v_beta_idx, chunk_size * S_v);
|
|
|
- continue;
|
|
|
- }
|
|
|
-
|
|
|
- float attn_val = attn_precomputed[attn_idx];
|
|
|
- float v_beta_val = v_beta_ptr[v_beta_idx];
|
|
|
-
|
|
|
- if (isnan(attn_val) || isnan(v_beta_val)) {
|
|
|
- // printf("ERROR: NaN detected in matrix multiplication: attn=%f, v_beta=%f\n", attn_val, v_beta_val);
|
|
|
- continue;
|
|
|
- }
|
|
|
-
|
|
|
- // Debug: Print first few multiplications for validation
|
|
|
- if (seq_idx == 0 && head_idx == 0 && i < 2 && d < 2 && j < 2) {
|
|
|
- // printf("C++ DEBUG value[%ld][%ld]: attn[%ld][%ld]=%f * v_beta[%ld][%ld]=%f = %f\n",
|
|
|
- // i, d, i, j, attn_val, j, d, v_beta_val, attn_val * v_beta_val);
|
|
|
- }
|
|
|
- sum += attn_val * v_beta_val;
|
|
|
- }
|
|
|
- value[i * S_v + d] = sum;
|
|
|
- // Debug: Print first few results for validation
|
|
|
- if (seq_idx == 0 && head_idx == 0 && i < 2 && d < 2) {
|
|
|
- // printf("C++ DEBUG value[%ld][%ld] = sum = %f\n", i, d, value[i * S_v + d]);
|
|
|
- }
|
|
|
- }
|
|
|
- }
|
|
|
-
|
|
|
- float value_sum = 0.0f;
|
|
|
- for (int64_t i = 0; i < chunk_size * S_v; i++) {
|
|
|
- value_sum += value[i];
|
|
|
- }
|
|
|
- // printf("C++ PRE-CHUNK value_sum = %f (head %ld, seq %ld)\n", value_sum, head_idx, seq_idx);
|
|
|
-
|
|
|
- // Compute k_cumdecay = attn @ (k_beta * g.exp().unsqueeze(-1)) with deterministic tensor access
|
|
|
- // printf("C++ DEBUG: Computing k_cumdecay = attn @ (k_beta * g.exp()) with deterministic tensor access\n");
|
|
|
- float * k_cumdecay = (float *) malloc(chunk_size * S_k * sizeof(float));
|
|
|
-
|
|
|
- // Calculate tensor strides for deterministic access
|
|
|
- const int64_t k_beta_stride0 = src7->nb[0] / sizeof(float); // S_k dimension
|
|
|
- const int64_t k_beta_stride1 = src7->nb[1] / sizeof(float); // chunk_size dimension
|
|
|
- const int64_t g_stride0 = src3->nb[0] / sizeof(float); // chunk_size dimension
|
|
|
-
|
|
|
- // printf("C++ DEBUG: k_cumdecay tensor strides: k_beta_stride0=%ld, k_beta_stride1=%ld, g_stride0=%ld\n",
|
|
|
- // k_beta_stride0, k_beta_stride1, g_stride0);
|
|
|
-
|
|
|
- for (int64_t i = 0; i < chunk_size; i++) {
|
|
|
- for (int64_t d = 0; d < S_k; d++) {
|
|
|
- float sum = 0.0f;
|
|
|
- for (int64_t j = 0; j < chunk_size; j++) {
|
|
|
- // Deterministic tensor access using stride calculations
|
|
|
- // attn[i][j] access: i * attn_stride0 + j
|
|
|
- int64_t attn_idx = i * attn_stride0 + j;
|
|
|
- if (attn_idx >= chunk_size * chunk_size) {
|
|
|
- // printf("ERROR: attn access out of bounds in k_cumdecay: attn_idx=%ld, max=%ld\n",
|
|
|
- // attn_idx, chunk_size * chunk_size);
|
|
|
- continue;
|
|
|
- }
|
|
|
-
|
|
|
- // k_beta[j][d] access: j * k_beta_stride1 + d
|
|
|
- int64_t k_beta_idx = j * k_beta_stride1 + d;
|
|
|
- if (k_beta_idx >= chunk_size * S_k) {
|
|
|
- // printf("ERROR: k_beta access out of bounds: k_beta_idx=%ld, max=%ld\n",
|
|
|
- // k_beta_idx, chunk_size * S_k);
|
|
|
- continue;
|
|
|
- }
|
|
|
-
|
|
|
- // g tensor layout: [chunk_size, n_heads, n_seqs, 1]
|
|
|
- // Deterministic access: g[j + head_idx * chunk_size + seq_idx * chunk_size * n_heads]
|
|
|
- int64_t g_idx = j + head_idx * chunk_size + seq_idx * chunk_size * H_v;
|
|
|
- if (g_idx >= chunk_size * H_v * n_seqs) {
|
|
|
- // printf("ERROR: g tensor out of bounds: g_idx=%ld, max=%ld\n",
|
|
|
- // g_idx, chunk_size * H_v * n_seqs);
|
|
|
- continue;
|
|
|
- }
|
|
|
-
|
|
|
- float attn_val = attn_precomputed[attn_idx];
|
|
|
- float k_beta_val = k_beta_ptr[k_beta_idx];
|
|
|
- float g_val = g_vals[g_idx];
|
|
|
- float g_exp = expf(g_val);
|
|
|
-
|
|
|
- if (isnan(attn_val) || isnan(k_beta_val) || isnan(g_val) || isnan(g_exp)) {
|
|
|
- // printf("ERROR: NaN detected in k_cumdecay multiplication: attn=%f, k_beta=%f, g_val=%f, g_exp=%f\n",
|
|
|
- // attn_val, k_beta_val, g_val, g_exp);
|
|
|
- continue;
|
|
|
- }
|
|
|
-
|
|
|
- // Debug: Print first few multiplications for validation
|
|
|
- if (seq_idx == 0 && head_idx == 0 && i < 2 && d < 2 && j < 2) {
|
|
|
- // printf("C++ DEBUG k_cumdecay[%ld][%ld]: attn[%ld][%ld]=%f * k_beta[%ld][%ld]=%f * g_exp[%ld]=%f = %f\n",
|
|
|
- // i, d, i, j, attn_val, j, d, k_beta_val, j, g_exp,
|
|
|
- // attn_val * k_beta_val * g_exp);
|
|
|
- }
|
|
|
- sum += attn_val * k_beta_val * g_exp;
|
|
|
- }
|
|
|
- k_cumdecay[i * S_k + d] = sum;
|
|
|
- // Debug: Print first few results for validation
|
|
|
- if (seq_idx == 0 && head_idx == 0 && i < 2 && d < 2) {
|
|
|
- // printf("C++ DEBUG k_cumdecay[%ld][%ld] = sum = %f\n", i, d, k_cumdecay[i * S_k + d]);
|
|
|
- }
|
|
|
- }
|
|
|
- }
|
|
|
-
|
|
|
- float k_cumdecay_sum = 0.0f;
|
|
|
- for (int64_t i = 0; i < chunk_size * S_k; i++) {
|
|
|
- k_cumdecay_sum += k_cumdecay[i];
|
|
|
- }
|
|
|
-
|
|
|
- free(value);
|
|
|
- free(k_cumdecay);
|
|
|
- }
|
|
|
- }
|
|
|
-
|
|
|
- // Initialize last_recurrent_state
|
|
|
- // last_recurrent_state = torch.zeros(batch_size, sequence_length, k_head_dim, v_head_dim).to(value)
|
|
|
- // if initial_state is None else initial_state.to(value)
|
|
|
- float * initial_state_ptr = (float *) src4->data;
|
|
|
-
|
|
|
- // If initial_state is provided, copy it to new_state, otherwise initialize new_state to zeros
|
|
|
- // Output state layout: [S_v * H_v, S_v * n_seqs, 1, 1]
|
|
|
- // This means: [n_heads * v_head_dim, v_head_dim * n_seqs, 1, 1]
|
|
|
- // So total size is: S_v * H_v * S_v * n_seqs
|
|
|
- if (initial_state_ptr != NULL) {
|
|
|
- // Copy initial state to new state
|
|
|
- memcpy(new_state, initial_state_ptr, S_v * H_v * S_v * n_seqs * sizeof(float));
|
|
|
- } else {
|
|
|
- // Initialize new state to zeros
|
|
|
- memset(new_state, 0, S_v * H_v * S_v * n_seqs * sizeof(float));
|
|
|
- }
|
|
|
-
|
|
|
- // Process each chunk for the main computation
|
|
|
- // Following the Python reference implementation exactly
|
|
|
- for (int64_t chunk_idx = 0; chunk_idx < num_chunks; chunk_idx++) {
|
|
|
- for (int64_t seq_idx = 0; seq_idx < n_seqs; seq_idx++) {
|
|
|
- // Process each head in this chunk
|
|
|
- for (int64_t head_idx = 0; head_idx < H_v; head_idx++) {
|
|
|
- // Get the recurrent state for this sequence and head
|
|
|
- // Output state layout: [S_v * H_v, S_v * n_seqs, 1, 1]
|
|
|
- // For each head and sequence, we need to find the correct state slice
|
|
|
- // The state is organized as: [head_idx * S_v * S_v * n_seqs + seq_idx * S_v * S_v]
|
|
|
- float * last_recurrent_state = new_state + (head_idx * S_v * S_v * n_seqs) + (seq_idx * S_v * S_v);
|
|
|
- // printf("\n=== C++ Processing chunk %ld, seq %ld, head %ld ===\n", chunk_idx, seq_idx, head_idx);
|
|
|
- // Get pointers to current chunk data for this head
|
|
|
- // GGML tensor layout: [S_k/S_v, chunk_size, H_v, n_seqs]
|
|
|
- // Python layout: [batch_size, sequence_length, num_heads, k_head_dim]
|
|
|
- // After transpose: [batch_size, num_heads, sequence_length, k_head_dim]
|
|
|
-
|
|
|
- // For GGML: ne[0]=S_k/S_v, ne[1]=chunk_size, ne[2]=H_v, ne[3]=n_seqs
|
|
|
- // nb[0]=sizeof(float)*S_k/S_v, nb[1]=sizeof(float)*S_k/S_v*chunk_size, etc.
|
|
|
-
|
|
|
- const int64_t q_offset = (seq_idx * src0->nb[3] + head_idx * src0->nb[2]) / sizeof(float);
|
|
|
- const int64_t k_offset = (seq_idx * src1->nb[3] + head_idx * src1->nb[2]) / sizeof(float);
|
|
|
- const int64_t v_offset = (seq_idx * src2->nb[3] + head_idx * src2->nb[2]) / sizeof(float);
|
|
|
- const int64_t g_offset = (seq_idx * src3->nb[3] + head_idx * src3->nb[2]) / sizeof(float);
|
|
|
- const int64_t v_beta_offset = (seq_idx * src6->nb[3] + head_idx * src6->nb[2]) / sizeof(float);
|
|
|
- const int64_t k_beta_offset = (seq_idx * src7->nb[3] + head_idx * src7->nb[2]) / sizeof(float);
|
|
|
- const int64_t attn_offset = (seq_idx * src8->nb[3] + head_idx * src8->nb[2]) / sizeof(float);
|
|
|
- const int64_t decay_mask_offset = (seq_idx * src5->nb[3] + head_idx * src5->nb[2]) / sizeof(float);
|
|
|
-
|
|
|
- // Calculate strides for each tensor
|
|
|
- const int64_t q_stride0 = src0->nb[0] / sizeof(float); // S_k
|
|
|
- const int64_t q_stride1 = src0->nb[1] / sizeof(float); // chunk_size
|
|
|
- const int64_t k_stride0 = src1->nb[0] / sizeof(float); // S_k
|
|
|
- const int64_t k_stride1 = src1->nb[1] / sizeof(float); // chunk_size
|
|
|
- const int64_t v_stride0 = src2->nb[0] / sizeof(float); // S_v
|
|
|
- const int64_t v_stride1 = src2->nb[1] / sizeof(float); // chunk_size
|
|
|
- const int64_t g_stride0 = src3->nb[0] / sizeof(float); // chunk_size
|
|
|
- const int64_t v_beta_stride0 = src6->nb[0] / sizeof(float); // S_v
|
|
|
- const int64_t v_beta_stride1 = src6->nb[1] / sizeof(float); // chunk_size
|
|
|
- const int64_t k_beta_stride0 = src7->nb[0] / sizeof(float); // S_k
|
|
|
- const int64_t k_beta_stride1 = src7->nb[1] / sizeof(float); // chunk_size
|
|
|
- const int64_t attn_stride0 = src8->nb[0] / sizeof(float); // chunk_size
|
|
|
- const int64_t attn_stride1 = src8->nb[1] / sizeof(float); // chunk_size
|
|
|
- const int64_t decay_mask_stride0 = src5->nb[0] / sizeof(float); // chunk_size
|
|
|
- const int64_t decay_mask_stride1 = src5->nb[1] / sizeof(float); // chunk_size
|
|
|
-
|
|
|
- // Get decay mask for this chunk and head
|
|
|
- float * decay_mask = (float *) src5->data + decay_mask_offset;
|
|
|
-
|
|
|
- // Use pre-computed attention matrix from src8 (after triangular updates)
|
|
|
- // The Python reference computes triangular updates before the chunk loop
|
|
|
- float * attn_precomputed = (float *) src8->data + attn_offset;
|
|
|
-
|
|
|
- // Debug: print precomputed attention matrix values
|
|
|
- float attn_precomputed_sum = 0.0f;
|
|
|
- float attn_precomputed_max = 0.0f;
|
|
|
- for (int64_t i = 0; i < chunk_size * chunk_size; i++) {
|
|
|
- attn_precomputed_sum += attn_precomputed[i];
|
|
|
- attn_precomputed_max = fmaxf(attn_precomputed_max, fabsf(attn_precomputed[i]));
|
|
|
- }
|
|
|
- // Get g values for this chunk and head
|
|
|
- float * g_vals = (float *) src3->data + g_offset;
|
|
|
-
|
|
|
- // Get v_beta and k_beta for this chunk and head
|
|
|
- float * v_beta_ptr = (float *) src6->data + v_beta_offset;
|
|
|
- float * k_beta_ptr = (float *) src7->data + k_beta_offset;
|
|
|
-
|
|
|
- // Debug: print v_beta and k_beta values
|
|
|
- float v_beta_sum = 0.0f;
|
|
|
- for (int64_t i = 0; i < chunk_size * S_v; i++) {
|
|
|
- v_beta_sum += v_beta_ptr[i];
|
|
|
- }
|
|
|
- // printf("C++ v_beta_sum = %f\n", v_beta_sum);
|
|
|
-
|
|
|
- float k_beta_sum = 0.0f;
|
|
|
- for (int64_t i = 0; i < chunk_size * S_k; i++) {
|
|
|
- k_beta_sum += k_beta_ptr[i];
|
|
|
- }
|
|
|
- // printf("C++ k_beta_sum = %f\n", k_beta_sum);
|
|
|
-
|
|
|
- // Compute value = attn_precomputed @ v_beta with bounds checking
|
|
|
- float * value = (float *) malloc(chunk_size * S_v * sizeof(float));
|
|
|
- for (int64_t i = 0; i < chunk_size; i++) {
|
|
|
- for (int64_t d = 0; d < S_v; d++) {
|
|
|
- float sum = 0.0f;
|
|
|
- for (int64_t j = 0; j < chunk_size; j++) {
|
|
|
- // Bounds checking for matrix multiplication
|
|
|
- if (i * chunk_size + j >= chunk_size * chunk_size ||
|
|
|
- j * v_beta_stride0 + d * v_beta_stride1 >= chunk_size * S_v) {
|
|
|
- // printf("ERROR: Chunk value matrix multiplication out of bounds: i=%ld, j=%ld, d=%ld\n", i, j, d);
|
|
|
- continue;
|
|
|
- }
|
|
|
-
|
|
|
- float attn_val = attn_precomputed[i * chunk_size + j];
|
|
|
- float v_beta_val = v_beta_ptr[j * v_beta_stride0 + d * v_beta_stride1];
|
|
|
-
|
|
|
- // Check for NaN values to prevent propagation
|
|
|
- if (isnan(attn_val) || isnan(v_beta_val)) {
|
|
|
- // printf("ERROR: NaN detected in chunk value multiplication: attn=%f, v_beta=%f\n", attn_val, v_beta_val);
|
|
|
- continue;
|
|
|
- }
|
|
|
-
|
|
|
- sum += attn_val * v_beta_val;
|
|
|
- }
|
|
|
- value[i * S_v + d] = sum;
|
|
|
- }
|
|
|
- }
|
|
|
-
|
|
|
- float value_sum = 0.0f;
|
|
|
- for (int64_t i = 0; i < chunk_size * S_v; i++) {
|
|
|
- value_sum += value[i];
|
|
|
- }
|
|
|
-
|
|
|
- // Compute k_cumdecay = attn_precomputed @ (k_beta * g.exp().unsqueeze(-1)) with bounds checking
|
|
|
- float * k_cumdecay = (float *) malloc(chunk_size * S_k * sizeof(float));
|
|
|
- for (int64_t i = 0; i < chunk_size; i++) {
|
|
|
- for (int64_t d = 0; d < S_k; d++) {
|
|
|
- float sum = 0.0f;
|
|
|
- for (int64_t j = 0; j < chunk_size; j++) {
|
|
|
- // Bounds checking for matrix multiplication
|
|
|
- if (i * chunk_size + j >= chunk_size * chunk_size ||
|
|
|
- j * k_beta_stride0 + d * k_beta_stride1 >= chunk_size * S_k ||
|
|
|
- j * g_stride0 >= chunk_size) {
|
|
|
- // printf("ERROR: Chunk k_cumdecay matrix multiplication out of bounds: i=%ld, j=%ld, d=%ld\n", i, j, d);
|
|
|
- continue;
|
|
|
- }
|
|
|
-
|
|
|
- float attn_val = attn_precomputed[i * chunk_size + j];
|
|
|
- float k_beta_val = k_beta_ptr[j * k_beta_stride0 + d * k_beta_stride1];
|
|
|
- float g_val = g_vals[j * g_stride0];
|
|
|
- float g_exp = expf(g_val);
|
|
|
-
|
|
|
- // Check for NaN values to prevent propagation
|
|
|
- if (isnan(attn_val) || isnan(k_beta_val) || isnan(g_val) || isnan(g_exp)) {
|
|
|
- // printf("ERROR: NaN detected in chunk k_cumdecay multiplication: attn=%f, k_beta=%f, g_val=%f, g_exp=%f\n",
|
|
|
- // attn_val, k_beta_val, g_val, g_exp);
|
|
|
- continue;
|
|
|
- }
|
|
|
-
|
|
|
- sum += attn_val * k_beta_val * g_exp;
|
|
|
- }
|
|
|
- k_cumdecay[i * S_k + d] = sum;
|
|
|
- }
|
|
|
- }
|
|
|
-
|
|
|
- float k_cumdecay_sum = 0.0f;
|
|
|
- for (int64_t i = 0; i < chunk_size * S_k; i++) {
|
|
|
- k_cumdecay_sum += k_cumdecay[i];
|
|
|
- }
|
|
|
+ float * v_beta_data = (float *) src6->data;
|
|
|
+ float * k_beta_data = (float *) src7->data;
|
|
|
+ float * g_data = (float *) src3->data;
|
|
|
+ float * q_data = (float *) src0->data;
|
|
|
+ float * k_data = (float *) src1->data;
|
|
|
+ //float * v_data = (float *) src2->data;
|
|
|
+ float * state_data = (float *) src4->data;
|
|
|
+ float * decay_mask_data = (float *) src5->data;
|
|
|
+
|
|
|
+ GGML_ASSERT(ggml_is_contiguous(src0));
|
|
|
+ GGML_ASSERT(ggml_is_contiguous(src1));
|
|
|
+ GGML_ASSERT(ggml_is_contiguous(src2));
|
|
|
+ GGML_ASSERT(ggml_is_contiguous(src3));
|
|
|
+ GGML_ASSERT(ggml_is_contiguous(src4));
|
|
|
+ GGML_ASSERT(ggml_is_contiguous(src5));
|
|
|
+ GGML_ASSERT(ggml_is_contiguous(src6));
|
|
|
+ GGML_ASSERT(ggml_is_contiguous(src7));
|
|
|
+ GGML_ASSERT(ggml_is_contiguous(src8));
|
|
|
+
|
|
|
+ for (int64_t seq = 0; seq < n_seqs; seq++) {
|
|
|
+ for (int64_t head = 0; head < H_v; head++) {
|
|
|
+ for (int64_t chunk = 0; chunk < num_chunks; chunk++) {
|
|
|
+ float * attn_data_for_chs = attn_data + (src8->nb[3] / sizeof(float)) * seq + (src8->nb[2] / sizeof(float)) * (chunk + head * num_chunks);
|
|
|
+ float * value_chunk = (float *) malloc(S_v * chunk_size * H_v * n_seqs * sizeof(float));
|
|
|
+ float * k_cumdecay = (float *) malloc(S_v * chunk_size * H_v * n_seqs * sizeof(float));
|
|
|
+ delta_apply_triangular_updates_f32(attn_data_for_chs, chunk_size);
|
|
|
+ delta_add_identity_matrix_f32(attn_data_for_chs, chunk_size);
|
|
|
+ // Calculate the correct v_beta and k_beta pointers for this head and sequence
|
|
|
+ float * v_beta_chunk = v_beta_data + (src6->nb[3] / sizeof(float)) * seq + (src6->nb[2] / sizeof(float)) * (chunk + head * num_chunks);
|
|
|
+ float * k_beta_chunk = k_beta_data + (src7->nb[3] / sizeof(float)) * seq + (src7->nb[2] / sizeof(float)) * (chunk + head * num_chunks);
|
|
|
+ // The g tensor has dimensions [8, 64, 2, 1] = [features, tokens, heads, sequences]
|
|
|
+ // We need to access the correct head data
|
|
|
+ // For each head, we need to access the correct feature for all tokens in the chunk
|
|
|
+ // Let's try accessing feature index chunk (since we have 8 features and chunk=0)
|
|
|
+ float * g_chunk = g_data + (src3->nb[3] / sizeof(float)) * seq + (src3->nb[2] / sizeof(float)) * head + (src3->nb[1] / sizeof(float)) * (chunk * chunk_size);
|
|
|
+ delta_compute_value_f32(attn_data_for_chs, v_beta_chunk, value_chunk, chunk_size, S_v);
|
|
|
+ delta_compute_k_cumdecay_f32(attn_data_for_chs, k_beta_chunk, g_chunk, k_cumdecay, chunk_size, S_k);
|
|
|
+ // Now compute the per-chunk-specific part (corresponding to the inner loop in Python)
|
|
|
+ float * q_chunk = q_data + (src0->nb[3] / sizeof(float)) * seq + (src0->nb[2] / sizeof(float)) * (chunk + head * num_chunks);
|
|
|
+ float * k_chunk = k_data + (src1->nb[3] / sizeof(float)) * seq + (src1->nb[2] / sizeof(float)) * (chunk + head * num_chunks);
|
|
|
+ float * decay_mask_chunk = decay_mask_data + (src5->nb[3] / sizeof(float)) * seq + (src5->nb[2] / sizeof(float)) * (chunk + head * num_chunks);
|
|
|
+ float * k_cumdecay_chunk = k_cumdecay + (S_v * chunk_size * H_v) * seq + (S_v * chunk_size) * head;
|
|
|
|
|
|
- // Compute fresh attention matrix for this chunk, just like Python reference line 118
|
|
|
- // attn = (q_i @ k_i.transpose(-1, -2) * decay_mask[:, :, i]).masked_fill_(mask, 0)
|
|
|
+ // Allocate temporary variables for the loop
|
|
|
float * attn = (float *) malloc(chunk_size * chunk_size * sizeof(float));
|
|
|
+ float * v_prime = (float *) malloc(chunk_size * S_v * sizeof(float));
|
|
|
+ float * v_new = (float *) malloc(chunk_size * S_v * sizeof(float));
|
|
|
+ float * attn_inter = (float *) malloc(chunk_size * S_v * sizeof(float));
|
|
|
+ float * core_attn_out_chunk = (float *) malloc(chunk_size * S_v * sizeof(float));
|
|
|
+ float * g_last = (float *) malloc(sizeof(float));
|
|
|
+ float * g_diff_exp = (float *) malloc(chunk_size * sizeof(float));
|
|
|
+ bool * mask = (bool *) malloc(chunk_size * chunk_size * sizeof(bool));
|
|
|
|
|
|
- // First compute q_i @ k_i.transpose(-1, -2) with bounds checking
|
|
|
+ // Create upper triangular mask for causal attention (exclude diagonal)
|
|
|
for (int64_t i = 0; i < chunk_size; i++) {
|
|
|
for (int64_t j = 0; j < chunk_size; j++) {
|
|
|
- float sum = 0.0f;
|
|
|
- for (int64_t d = 0; d < S_k; d++) {
|
|
|
- // Bounds checking for q and k tensor access
|
|
|
- int64_t q_idx = q_offset + d * q_stride0 + i * q_stride1;
|
|
|
- int64_t k_idx = k_offset + d * k_stride0 + j * k_stride1;
|
|
|
-
|
|
|
- if (q_idx >= src0->ne[0] * src0->ne[1] * src0->ne[2] * src0->ne[3] ||
|
|
|
- k_idx >= src1->ne[0] * src1->ne[1] * src1->ne[2] * src1->ne[3]) {
|
|
|
- // printf("ERROR: q/k tensor access out of bounds: q_idx=%ld, k_idx=%ld\n", q_idx, k_idx);
|
|
|
- continue;
|
|
|
- }
|
|
|
-
|
|
|
- float q_val = ((float *)src0->data)[q_idx];
|
|
|
- float k_val = ((float *)src1->data)[k_idx];
|
|
|
-
|
|
|
- // Check for NaN values to prevent propagation
|
|
|
- if (isnan(q_val) || isnan(k_val)) {
|
|
|
- // printf("ERROR: NaN detected in q@k multiplication: q_val=%f, k_val=%f\n", q_val, k_val);
|
|
|
- continue;
|
|
|
- }
|
|
|
-
|
|
|
- sum += q_val * k_val;
|
|
|
- }
|
|
|
-
|
|
|
- // Bounds checking for decay mask access
|
|
|
- if (i * chunk_size + j >= chunk_size * chunk_size) {
|
|
|
- // printf("ERROR: decay mask access out of bounds: i=%ld, j=%ld\n", i, j);
|
|
|
- attn[i * chunk_size + j] = 0.0f;
|
|
|
- continue;
|
|
|
- }
|
|
|
-
|
|
|
- float decay_val = decay_mask[i * chunk_size + j];
|
|
|
- if (isnan(decay_val)) {
|
|
|
- // printf("ERROR: NaN detected in decay mask: decay_val=%f\n", decay_val);
|
|
|
- attn[i * chunk_size + j] = 0.0f;
|
|
|
- continue;
|
|
|
- }
|
|
|
-
|
|
|
- attn[i * chunk_size + j] = sum * decay_val;
|
|
|
- }
|
|
|
- }
|
|
|
-
|
|
|
- // Apply upper triangular mask (masked_fill_(mask, 0))
|
|
|
- for (int64_t i = 0; i < chunk_size; i++) {
|
|
|
- for (int64_t j = i + 1; j < chunk_size; j++) {
|
|
|
- attn[i * chunk_size + j] = 0.0f;
|
|
|
+ mask[i * chunk_size + j] = (j > i); // True for upper triangular (excluding diagonal)
|
|
|
}
|
|
|
}
|
|
|
+
|
|
|
+ // Python loop implementation:
|
|
|
+ // q_i, k_i, v_i = query[:, :, i], key[:, :, i], value[:, :, i]
|
|
|
+ // attn = (q_i @ k_i.transpose(-1, -2) * decay_mask[:, :, i]).masked_fill_(mask, 0)
|
|
|
+ delta_compute_q_k_attn_f32(q_chunk, k_chunk, decay_mask_chunk, attn, mask, chunk_size, S_k);
|
|
|
|
|
|
- // Compute v_prime = k_cumdecay @ last_recurrent_state
|
|
|
- float * v_prime = (float *) malloc(chunk_size * S_v * sizeof(float));
|
|
|
- for (int64_t i = 0; i < chunk_size; i++) {
|
|
|
- for (int64_t d = 0; d < S_v; d++) {
|
|
|
- float sum = 0.0f;
|
|
|
- for (int64_t k = 0; k < S_k; k++) {
|
|
|
- sum += k_cumdecay[i * S_k + k] * last_recurrent_state[k * S_v + d];
|
|
|
- }
|
|
|
- v_prime[i * S_v + d] = sum;
|
|
|
- }
|
|
|
- }
|
|
|
+ // v_prime = (k_cumdecay[:, :, i]) @ last_recurrent_state
|
|
|
+ // Calculate the correct state pointer for this head and sequence
|
|
|
+ float * head_state_data = state_data + (seq * S_v * S_v * H_v) + (head * S_v * S_v);
|
|
|
|
|
|
- // Debug prints for key intermediate values
|
|
|
- float attn_sum = 0.0f;
|
|
|
- for (int64_t i = 0; i < chunk_size * chunk_size; i++) {
|
|
|
- attn_sum += attn[i];
|
|
|
- }
|
|
|
|
|
|
+ delta_matmul_state_f32(k_cumdecay_chunk, head_state_data, v_prime, chunk_size, S_k, S_v);
|
|
|
|
|
|
- float v_prime_sum = 0.0f;
|
|
|
- for (int64_t i = 0; i < chunk_size * S_v; i++) {
|
|
|
- v_prime_sum += v_prime[i];
|
|
|
- }
|
|
|
+ // v_new = v_i - v_prime
|
|
|
+ delta_tensor_subtract_f32(value_chunk, v_prime, v_new, chunk_size * S_v);
|
|
|
|
|
|
- // Compute v_new = v_i - v_prime
|
|
|
- float * v_new = (float *) malloc(chunk_size * S_v * sizeof(float));
|
|
|
+ // attn_inter = (q_i * g[:, :, i, :, None].exp()) @ last_recurrent_state
|
|
|
+ float * q_g_exp = (float *) malloc(chunk_size * S_k * sizeof(float));
|
|
|
for (int64_t i = 0; i < chunk_size; i++) {
|
|
|
- for (int64_t d = 0; d < S_v; d++) {
|
|
|
- v_new[i * S_v + d] = ((float *)src2->data)[v_offset + d * v_stride0 + i * v_stride1] - v_prime[i * S_v + d];
|
|
|
+ for (int64_t d = 0; d < S_k; d++) {
|
|
|
+ int64_t q_idx = i * S_k + d;
|
|
|
+ q_g_exp[q_idx] = q_chunk[q_idx] * expf(g_chunk[i]);
|
|
|
}
|
|
|
}
|
|
|
+ delta_matmul_state_f32(q_g_exp, head_state_data, attn_inter, chunk_size, S_k, S_v);
|
|
|
|
|
|
- float v_new_sum = 0.0f;
|
|
|
- for (int64_t i = 0; i < chunk_size * S_v; i++) {
|
|
|
- v_new_sum += v_new[i];
|
|
|
- }
|
|
|
+ // core_attn_out[:, :, i] = attn_inter + attn @ v_new
|
|
|
+ float * attn_v_new = (float *) malloc(chunk_size * S_v * sizeof(float));
|
|
|
+ delta_matmul_state_f32(attn, v_new, attn_v_new, chunk_size, chunk_size, S_v);
|
|
|
+ delta_tensor_add_f32(attn_inter, attn_v_new, core_attn_out_chunk, chunk_size * S_v);
|
|
|
|
|
|
- // Compute attn_inter = (q_i * g[:, :, i, :, None].exp()) @ last_recurrent_state
|
|
|
- float * attn_inter = (float *) malloc(chunk_size * S_v * sizeof(float));
|
|
|
+ // Store the result in the output tensor
|
|
|
for (int64_t i = 0; i < chunk_size; i++) {
|
|
|
for (int64_t d = 0; d < S_v; d++) {
|
|
|
- float sum = 0.0f;
|
|
|
- float g_exp = expf(g_vals[i * g_stride0]);
|
|
|
- for (int64_t k = 0; k < S_k; k++) {
|
|
|
- sum += ((float *)src0->data)[q_offset + k * q_stride0 + i * q_stride1] * g_exp * last_recurrent_state[k * S_v + d];
|
|
|
- }
|
|
|
- attn_inter[i * S_v + d] = sum;
|
|
|
+ if ((chunk * chunk_size + i) >= n_tokens) continue;
|
|
|
+ int64_t output_idx = seq * (n_tokens * S_v * H_v) + head * (n_tokens * S_v) + (chunk * chunk_size + i) * S_v + d;
|
|
|
+ output[output_idx] = core_attn_out_chunk[i * S_v + d];
|
|
|
}
|
|
|
}
|
|
|
|
|
|
- float attn_inter_sum = 0.0f;
|
|
|
- for (int64_t i = 0; i < chunk_size * S_v; i++) {
|
|
|
- attn_inter_sum += attn_inter[i];
|
|
|
- }
|
|
|
-
|
|
|
- // Compute core_attn_out = attn_inter + attn @ v_new
|
|
|
- // Output tensor layout: [S_v * H_v, n_tokens, 1, 1]
|
|
|
- const int64_t out_offset = head_idx * (S_v * n_tokens) + chunk_idx * (S_v * chunk_size);
|
|
|
- float * core_attn_out = output + out_offset;
|
|
|
+ // g_last = g[:, :, i, -1, None, None].exp()
|
|
|
+ *g_last = expf(g_chunk[chunk_size - 1]);
|
|
|
|
|
|
+ // Prepare g_diff_exp = (g[:, :, i, -1, None] - g[:, :, i]).exp()
|
|
|
+ float g_last_val = g_chunk[chunk_size - 1];
|
|
|
for (int64_t i = 0; i < chunk_size; i++) {
|
|
|
- for (int64_t d = 0; d < S_v; d++) {
|
|
|
- float sum = 0.0f;
|
|
|
- for (int64_t j = 0; j < chunk_size; j++) {
|
|
|
- sum += attn[i * chunk_size + j] * v_new[j * S_v + d];
|
|
|
- }
|
|
|
- core_attn_out[i * S_v + d] = attn_inter[i * S_v + d] + sum;
|
|
|
- }
|
|
|
+ g_diff_exp[i] = expf(g_last_val - g_chunk[i]);
|
|
|
}
|
|
|
|
|
|
- float core_attn_out_sum = 0.0f;
|
|
|
- for (int64_t i = 0; i < chunk_size * S_v; i++) {
|
|
|
- core_attn_out_sum += core_attn_out[i];
|
|
|
- }
|
|
|
+ // last_recurrent_state = (
|
|
|
+ // last_recurrent_state * g_last
|
|
|
+ // + (k_i * (g[:, :, i, -1, None] - g[:, :, i]).exp()[..., None]).transpose(-1, -2) @ v_new
|
|
|
+ // )
|
|
|
+ float * new_recurrent_state = (float *) malloc(S_v * S_v * sizeof(float));
|
|
|
|
|
|
- float g_last = g_vals[chunk_size - 1];
|
|
|
- float g_last_exp = expf(g_last);
|
|
|
|
|
|
- // First part: last_recurrent_state * g_last_exp
|
|
|
- for (int64_t k = 0; k < S_k; k++) {
|
|
|
- for (int64_t v = 0; v < S_v; v++) {
|
|
|
- last_recurrent_state[k * S_v + v] *= g_last_exp;
|
|
|
- }
|
|
|
- }
|
|
|
+ delta_update_recurrent_state_f32(head_state_data, g_last, k_chunk, g_diff_exp, v_new,
|
|
|
+ new_recurrent_state, chunk_size, S_v, S_v);
|
|
|
|
|
|
- // Second part: (k_i * (g_last - g).exp()).transpose(-1, -2) @ v_new
|
|
|
- float * k_gated = (float *) malloc(chunk_size * S_k * sizeof(float));
|
|
|
- for (int64_t i = 0; i < chunk_size; i++) {
|
|
|
- float g_diff_exp = expf(g_last - g_vals[i]);
|
|
|
- for (int64_t d = 0; d < S_k; d++) {
|
|
|
- k_gated[i * S_k + d] = ((float *)src1->data)[k_offset + d * k_stride0 + i * k_stride1] * g_diff_exp;
|
|
|
+
|
|
|
+ // Store the new state
|
|
|
+ for (int64_t i = 0; i < S_v; i++) {
|
|
|
+ for (int64_t j = 0; j < S_v; j++) {
|
|
|
+ int64_t state_idx = seq * S_v * S_v * H_v + head * S_v * S_v + i * S_v + j;
|
|
|
+ new_state[state_idx] = new_recurrent_state[i * S_v + j];
|
|
|
}
|
|
|
}
|
|
|
|
|
|
- // Compute k_gated.T @ v_new
|
|
|
- for (int64_t k = 0; k < S_k; k++) {
|
|
|
- for (int64_t v = 0; v < S_v; v++) {
|
|
|
- float sum = 0.0f;
|
|
|
- for (int64_t i = 0; i < chunk_size; i++) {
|
|
|
- sum += k_gated[i * S_k + k] * v_new[i * S_v + v];
|
|
|
- }
|
|
|
- last_recurrent_state[k * S_v + v] += sum;
|
|
|
+ // Update the original state tensor with the new state for the next chunk
|
|
|
+ for (int64_t i = 0; i < S_v; i++) {
|
|
|
+ for (int64_t j = 0; j < S_v; j++) {
|
|
|
+ int64_t state_idx = i * S_v + j;
|
|
|
+ head_state_data[state_idx] = new_recurrent_state[state_idx];
|
|
|
}
|
|
|
}
|
|
|
|
|
|
+ // Recalculate head_state_data to point to the updated state for the next iteration
|
|
|
+ head_state_data = state_data + (seq * S_v * S_v * H_v) + (head * S_v * S_v);
|
|
|
+
|
|
|
// Free temporary memory
|
|
|
free(attn);
|
|
|
- free(value);
|
|
|
- free(k_cumdecay);
|
|
|
free(v_prime);
|
|
|
free(v_new);
|
|
|
free(attn_inter);
|
|
|
- free(k_gated);
|
|
|
+ free(core_attn_out_chunk);
|
|
|
+ free(g_last);
|
|
|
+ free(g_diff_exp);
|
|
|
+ free(mask);
|
|
|
+ free(q_g_exp);
|
|
|
+ free(attn_v_new);
|
|
|
+ free(new_recurrent_state);
|
|
|
+
|
|
|
+ // Free the value and k_cumdecay allocated at the beginning of the loop
|
|
|
+ free(value_chunk);
|
|
|
+ free(k_cumdecay);
|
|
|
}
|
|
|
}
|
|
|
- }
|
|
|
+ }
|
|
|
}
|
|
|
|
|
|
// ggml_compute_forward_rwkv_wkv7
|