|
|
@@ -9,6 +9,7 @@
|
|
|
|
|
|
#include <float.h>
|
|
|
#include <algorithm>
|
|
|
+#include <cmath>
|
|
|
|
|
|
// ggml_compute_forward_dup
|
|
|
|
|
|
@@ -2171,6 +2172,57 @@ void ggml_compute_forward_sum(
|
|
|
}
|
|
|
}
|
|
|
|
|
|
+// ggml_compute_forward_cumsum
|
|
|
+
|
|
|
+static void ggml_compute_forward_cumsum_f32(
|
|
|
+ const ggml_compute_params * params,
|
|
|
+ ggml_tensor * dst) {
|
|
|
+
|
|
|
+ const ggml_tensor * src0 = dst->src[0];
|
|
|
+
|
|
|
+ if (params->ith != 0) {
|
|
|
+ return;
|
|
|
+ }
|
|
|
+
|
|
|
+ GGML_ASSERT(src0->nb[0] == sizeof(float));
|
|
|
+ GGML_ASSERT(dst->nb[0] == sizeof(float));
|
|
|
+
|
|
|
+ GGML_TENSOR_UNARY_OP_LOCALS
|
|
|
+
|
|
|
+ GGML_ASSERT(ne0 == ne00);
|
|
|
+ GGML_ASSERT(ne1 == ne01);
|
|
|
+ GGML_ASSERT(ne2 == ne02);
|
|
|
+ GGML_ASSERT(ne3 == ne03);
|
|
|
+
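|
|
|
+    // cumulative sum along dim 0 of each row: dst_row[i] = src_row[0] + ... + src_row[i]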
|
|
|
+ for (int64_t i3 = 0; i3 < ne03; i3++) {
|
|
|
+ for (int64_t i2 = 0; i2 < ne02; i2++) {
|
|
|
+ for (int64_t i1 = 0; i1 < ne01; i1++) {
|
|
|
+ float * src_row = (float *) ((char *) src0->data + i1*nb01 + i2*nb02 + i3*nb03);
|
|
|
+ float * dst_row = (float *) ((char *) dst->data + i1*nb1 + i2*nb2 + i3*nb3);
|
|
|
+ ggml_vec_cumsum_f32(ne00, dst_row, src_row);
|
|
|
+ }
|
|
|
+ }
|
|
|
+ }
|
|
|
+}
|
|
|
+
|
|
|
+void ggml_compute_forward_cumsum(
|
|
|
+ const ggml_compute_params * params,
|
|
|
+ ggml_tensor * dst) {
|
|
|
+
|
|
|
+ const ggml_tensor * src0 = dst->src[0];
|
|
|
+
|
|
|
+ switch (src0->type) {
|
|
|
+ case GGML_TYPE_F32:
|
|
|
+ {
|
|
|
+ ggml_compute_forward_cumsum_f32(params, dst);
|
|
|
+ } break;
|
|
|
+ default:
|
|
|
+ {
|
|
|
+ GGML_ABORT("fatal error");
|
|
|
+ }
|
|
|
+ }
|
|
|
+}
|
|
|
+
|
|
|
// ggml_compute_forward_sum_rows
|
|
|
|
|
|
static void ggml_compute_forward_sum_rows_f32(
|
|
|
@@ -2917,6 +2969,49 @@ static void ggml_compute_forward_gelu(
|
|
|
}
|
|
|
}
|
|
|
|
|
|
+// ggml_compute_forward_tri
|
|
|
+
|
|
|
+static void ggml_compute_forward_tri_f32(const ggml_compute_params * params, ggml_tensor * dst) {
|
|
|
+ const ggml_tensor * src0 = dst->src[0];
|
|
|
+
|
|
|
+ ggml_tri_type ttype = (ggml_tri_type) dst->op_params[0];
|
|
|
+ float c = *((float *) &(dst->op_params[1]));
|
|
|
+ bool keep_org_val = isnan(c);
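|
|
|
+    // op_params layout: [0] = ggml_tri_type, [1] = fill constant c (stored as float bits); a NaN constant selects keep_org_val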
|
|
|
+
|
|
|
+ GGML_ASSERT(ggml_is_contiguous(src0));
|
|
|
+ GGML_ASSERT(src0->ne[0] == src0->ne[1]);
|
|
|
+
|
|
|
+ GGML_TENSOR_UNARY_OP_LOCALS
|
|
|
+
|
|
|
+ const auto [ir0, ir1] = get_thread_range(params, src0);
|
|
|
+
|
|
|
+ for (int64_t ir = ir0; ir < ir1; ++ir) {
|
|
|
+ const int64_t i03 = ir/(ne02*ne01);
|
|
|
+ const int64_t i02 = (ir - i03*ne02*ne01)/ne01;
|
|
|
+ const int64_t i01 = (ir - i03*ne02*ne01 - i02*ne01);
|
|
|
+
|
|
|
+ float * dst_ptr = (float *) ((char *) dst->data + i03*nb3 + i02*nb2 + i01*nb1 );
|
|
|
+ float * src = (float *) ((char *) src0->data + i03*nb03 + i02*nb02 + i01*nb01 );
|
|
|
+ ggml_vec_tri_f32(ne0, i01, dst_ptr, src, keep_org_val, c, ttype);
|
|
|
+ }
|
|
|
+
|
|
|
+}
|
|
|
+
|
|
|
+void ggml_compute_forward_tri(const ggml_compute_params * params, ggml_tensor * dst) {
|
|
|
+ const ggml_tensor * src0 = dst->src[0];
|
|
|
+
|
|
|
+ switch (src0->type) {
|
|
|
+ case GGML_TYPE_F32:
|
|
|
+ {
|
|
|
+ ggml_compute_forward_tri_f32(params, dst);
|
|
|
+ } break;
|
|
|
+ default:
|
|
|
+ {
|
|
|
+ GGML_ABORT("fatal error");
|
|
|
+ }
|
|
|
+ }
|
|
|
+}
|
|
|
+
|
|
|
// ggml_compute_forward_gelu_erf
|
|
|
|
|
|
static void ggml_compute_forward_gelu_erf_f32(
|
|
|
@@ -10362,8 +10457,596 @@ void ggml_compute_forward_gla(
|
|
|
}
|
|
|
}
|
|
|
|
|
|
-// ggml_compute_forward_rwkv_wkv7
|
|
|
+// Helper function to compute cumulative sum
|
|
|
+static void ggml_cumsum_f32(const float * x, float * dst, const int64_t n) {
|
|
|
+ float cumsum = 0.0f;
|
|
|
+ for (int64_t i = 0; i < n; i++) {
|
|
|
+ cumsum += x[i];
|
|
|
+ dst[i] = cumsum;
|
|
|
+ }
|
|
|
+}
|
|
|
+
|
|
|
+// Helper function for matrix multiplication
|
|
|
+static void ggml_matmul_f32(const float * a, const float * b, float * dst,
|
|
|
+ const int64_t m, const int64_t n, const int64_t k) {
|
|
|
+ for (int64_t i = 0; i < m; i++) {
|
|
|
+ for (int64_t j = 0; j < n; j++) {
|
|
|
+ float sum = 0.0f;
|
|
|
+ for (int64_t l = 0; l < k; l++) {
|
|
|
+ sum += a[i * k + l] * b[l * n + j];
|
|
|
+ }
|
|
|
+ dst[i * n + j] = sum;
|
|
|
+ }
|
|
|
+ }
|
|
|
+}
|
|
|
+
|
|
|
+// Helper function to create upper triangular mask
|
|
|
+static void ggml_create_upper_triangular_mask(bool * mask, const int64_t size) {
|
|
|
+ for (int64_t i = 0; i < size; i++) {
|
|
|
+ for (int64_t j = 0; j < size; j++) {
|
|
|
+ mask[i * size + j] = (j >= i); // upper triangular with diagonal
|
|
|
+ }
|
|
|
+ }
|
|
|
+}
|
|
|
+
|
|
|
+// Helper function to compute chunk decay mask
|
|
|
+static void ggml_compute_chunk_decay_mask_f32(const float * g_cumsum, float * decay_mask,
|
|
|
+ const int64_t chunk_size) {
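|
|
|
+    // decay_mask[i][j] = exp(g_cumsum[j] - g_cumsum[i]) for i >= j, and 0 above the diagonal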
|
|
|
+ for (int64_t i = 0; i < chunk_size; i++) {
|
|
|
+ for (int64_t j = 0; j < chunk_size; j++) {
|
|
|
+ if (i >= j) { // Only compute for lower triangular (including diagonal)
|
|
|
+ float g_diff = g_cumsum[i] - g_cumsum[j];
|
|
|
+ decay_mask[i * chunk_size + j] = expf(-g_diff);
|
|
|
+ } else {
|
|
|
+ decay_mask[i * chunk_size + j] = 0.0f; // Causal mask
|
|
|
+ }
|
|
|
+ }
|
|
|
+ }
|
|
|
+}
|
|
|
+
|
|
|
+// Helper function to compute k_beta @ key.T
|
|
|
+static void ggml_compute_k_beta_key_t_f32(const float * k_beta, const float * key,
|
|
|
+ float * k_beta_key_t,
|
|
|
+ const int64_t chunk_size, const int64_t k_head_dim) {
|
|
|
+ for (int64_t i = 0; i < chunk_size; i++) {
|
|
|
+ for (int64_t j = 0; j < chunk_size; j++) {
|
|
|
+ float sum = 0.0f;
|
|
|
+ for (int64_t d = 0; d < k_head_dim; d++) {
|
|
|
+ int64_t k_beta_idx = i * k_head_dim + d;
|
|
|
+ int64_t key_idx = j * k_head_dim + d;
|
|
|
+ sum += k_beta[k_beta_idx] * key[key_idx];
|
|
|
+ }
|
|
|
+ k_beta_key_t[i * chunk_size + j] = sum;
|
|
|
+ }
|
|
|
+ }
|
|
|
+}
|
|
|
+
|
|
|
+// Helper function to apply triangular updates
|
|
|
+static void ggml_apply_triangular_updates_f32(float * attn, const int64_t chunk_size) {
|
|
|
+ for (int64_t i = 1; i < chunk_size; i++) {
|
|
|
+ for (int64_t j = 0; j < i; j++) {
|
|
|
+ float sum = 0.0f;
|
|
|
+ for (int64_t k = 0; k < i; k++) {
|
|
|
+ sum += attn[i * chunk_size + k] * attn[k * chunk_size + j];
|
|
|
+ }
|
|
|
+ attn[i * chunk_size + j] += sum;
|
|
|
+ }
|
|
|
+ }
|
|
|
+}
|
|
|
+
|
|
|
+// Helper function to add identity matrix
|
|
|
+static void ggml_add_identity_matrix_f32(float * matrix, const int64_t size) {
|
|
|
+ for (int64_t i = 0; i < size; i++) {
|
|
|
+ matrix[i * size + i] += 1.0f;
|
|
|
+ }
|
|
|
+}
|
|
|
+
|
|
|
+// Helper function to compute value = attn @ v_beta
|
|
|
+static void ggml_compute_value_f32(const float * attn, const float * v_beta,
|
|
|
+ float * value,
|
|
|
+ const int64_t chunk_size, const int64_t v_head_dim) {
|
|
|
+ for (int64_t i = 0; i < chunk_size; i++) {
|
|
|
+ for (int64_t d = 0; d < v_head_dim; d++) {
|
|
|
+ float sum = 0.0f;
|
|
|
+ for (int64_t j = 0; j < chunk_size; j++) {
|
|
|
+ int64_t v_beta_idx = j * v_head_dim + d;
|
|
|
+ sum += attn[i * chunk_size + j] * v_beta[v_beta_idx];
|
|
|
+ }
|
|
|
+ value[i * v_head_dim + d] = sum;
|
|
|
+ }
|
|
|
+ }
|
|
|
+}
|
|
|
+
|
|
|
+// Helper function to compute k_cumdecay = attn @ (k_beta * g.exp().unsqueeze(-1))
|
|
|
+static void ggml_compute_k_cumdecay_f32(const float * attn, const float * k_beta, const double * g_exp,
|
|
|
+ float * k_cumdecay,
|
|
|
+ const int64_t chunk_size, const int64_t k_head_dim) {
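|
|
|
+    // k_cumdecay[i][d] = sum_j attn[i][j] * k_beta[j][d] * g_exp[j]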
|
|
|
+ for (int64_t i = 0; i < chunk_size; i++) {
|
|
|
+ for (int64_t d = 0; d < k_head_dim; d++) {
|
|
|
+ float sum = 0.0f;
|
|
|
+ for (int64_t j = 0; j < chunk_size; j++) {
|
|
|
+ int64_t k_beta_idx = j * k_head_dim + d;
|
|
|
+ sum += attn[i * chunk_size + j] * k_beta[k_beta_idx] * g_exp[j];
|
|
|
+ }
|
|
|
+ k_cumdecay[i * k_head_dim + d] = sum;
|
|
|
+ }
|
|
|
+ }
|
|
|
+}
|
|
|
+// Helper functions for delta net computation
|
|
|
+
|
|
|
+// Matrix multiplication helper for delta net
|
|
|
+static void ggml_delta_net_matmul_f32(const float * a, const int64_t rows_a, const int64_t cols_a, const int64_t cols_b,
|
|
|
+ const float * b, float * result) {
|
|
|
+ for (int64_t i = 0; i < rows_a; i++) {
|
|
|
+ for (int64_t j = 0; j < cols_b; j++) {
|
|
|
+ float sum = 0.0f;
|
|
|
+ for (int64_t k = 0; k < cols_a; k++) {
|
|
|
+ int64_t a_idx = i * cols_a + k;
|
|
|
+ int64_t b_idx = k * cols_b + j;
|
|
|
+ sum += a[a_idx] * b[b_idx];
|
|
|
+ }
|
|
|
+ result[i * cols_b + j] = sum;
|
|
|
+ }
|
|
|
+ }
|
|
|
+}
|
|
|
+
|
|
|
+void ggml_compute_forward_delta_net_f32(const ggml_compute_params * params, ggml_tensor * dst) {
|
|
|
+ const struct ggml_tensor * src0 = dst->src[0]; // q (already normalized and scaled)
|
|
|
+ const struct ggml_tensor * src1 = dst->src[1]; // k (already normalized)
|
|
|
+ const struct ggml_tensor * src2 = dst->src[2]; // v
|
|
|
+ const struct ggml_tensor * src3 = dst->src[3]; // g (cumsum)
|
|
|
+ const struct ggml_tensor * src4 = dst->src[4]; // state
|
|
|
+ const struct ggml_tensor * src5 = dst->src[5]; // decay_mask
|
|
|
+ const struct ggml_tensor * src6 = dst->src[6]; // v_beta
|
|
|
+ const struct ggml_tensor * src7 = dst->src[7]; // k_beta
|
|
|
+ const struct ggml_tensor * src8 = dst->src[8]; // attn
|
|
|
+
|
|
|
+ const int64_t H_v = (int64_t) dst->op_params[0];
|
|
|
+ const int64_t S_k = (int64_t) dst->op_params[1];
|
|
|
+ const int64_t S_v = (int64_t) dst->op_params[2];
|
|
|
+ const int64_t original_n_tokens = (int64_t) dst->op_params[3]; // Get original sequence length
|
|
|
+ const int64_t H_k = H_v;
|
|
|
+ const int64_t n_tokens = original_n_tokens; // Use the original sequence length
|
|
|
+ const int64_t n_seqs = src0->ne[3]; // q tensor has n_seqs in dim 3
|
|
|
+
|
|
|
+ // Add assertions to verify tensor dimensions
|
|
|
+ GGML_ASSERT(src0->ne[3] == n_seqs); // q tensor
|
|
|
+ GGML_ASSERT(src1->ne[3] == n_seqs); // k tensor
|
|
|
+ GGML_ASSERT(src2->ne[3] == n_seqs); // v tensor
|
|
|
+ GGML_ASSERT(src3->ne[3] == n_seqs); // g tensor
|
|
|
+ GGML_ASSERT(src4->ne[3] == n_seqs); // beta tensor
|
|
|
+
|
|
|
+ float * dst_data = (float *) dst->data;
|
|
|
+ // Following GLA pattern: output is first part, state is second part
|
|
|
+ float * output = dst_data; // [S_v * H_v, n_tokens, 1, 1] - only real sequence length, not padded
|
|
|
+ float * new_state = dst_data + (S_v * H_v * n_tokens); // [S_v * H_v, S_v * n_seqs, 1, 1]
|
|
|
+
|
|
|
+ const int ith = params->ith;
|
|
|
+ // const int nth = params->nth; // nth is unused
|
|
|
+
|
|
|
+ // For chunked implementation, we process all sequences in thread 0 for simplicity
|
|
|
+ // This can be optimized later to parallelize across sequences
|
|
|
+ if (ith != 0) {
|
|
|
+ return;
|
|
|
+ }
|
|
|
+
|
|
|
+ // Clear output and new state section
|
|
|
+ memset(output, 0, ((S_v * H_v * n_tokens) + (S_v * H_v * S_v * n_seqs)) * sizeof(float));
|
|
|
+
|
|
|
+ // Get tensor data pointers
|
|
|
+ float * state_data = (float *) src4->data;
|
|
|
+ float * decay_mask = (float *) src5->data;
|
|
|
+
|
|
|
+ // Allocate temporary buffers for computation
|
|
|
+ const int64_t chunk_size = GGML_DELTA_NET_CHUNK;
|
|
|
+ // The first dimension is the chunk_size, second is head_dim, third is num_heads, fourth is n_seqs
|
|
|
+ // Note: In reference Python implementation, tensors are padded to multiple of chunk_size
|
|
|
+ // but the output only contains the real sequence length, not the padded length
|
|
|
+
|
|
|
+ // Calculate the actual padded sequence length for internal processing
|
|
|
+ const int64_t pad_size = (chunk_size - n_tokens % chunk_size) % chunk_size;
|
|
|
+ const int64_t total_sequence_length = n_tokens + pad_size;
|
|
|
+ const int64_t n_chunks = (total_sequence_length + chunk_size - 1) / chunk_size; // Ceiling division
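|
|
|
+    // e.g. n_tokens = 100 with chunk_size = 64 gives pad_size = 28, total_sequence_length = 128, n_chunks = 2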
|
|
|
+
|
|
|
+ // Temporary buffers for each chunk
|
|
|
+ std::vector<float> attn(chunk_size * chunk_size, 0.0f);
|
|
|
+ std::vector<float> value(chunk_size * S_v, 0.0f);
|
|
|
+ std::vector<float> k_cumdecay(chunk_size * S_k, 0.0f);
|
|
|
+ std::vector<double> g_exp(chunk_size, 0.0f);
|
|
|
+ std::vector<float> g_cumsum(chunk_size, 0.0f);
|
|
|
+ std::vector<float> last_state(S_v * S_v * H_v, 0.0f);
|
|
|
+
|
|
|
+ // Initialize last_state with input state data
|
|
|
+    // State tensor layout in GGML: [S_v, S_v * H_v, 1, 1]; the second dimension packs S_v values per head
|
|
|
+ // For delta_net, S_k == S_v (both k and v have the same head dimension)
|
|
|
+ for (int64_t h = 0; h < H_v; h++) {
|
|
|
+ for (int64_t d1 = 0; d1 < S_v; d1++) {
|
|
|
+ for (int64_t d2 = 0; d2 < S_v; d2++) {
|
|
|
+ // GGML state index: [d1, d2 + h*S_v, 0, 0] in flattened form
|
|
|
+ int64_t ggml_state_idx = d1 * (S_v * H_v) + (d2 + h * S_v);
|
|
|
+ // Our computed state index: [d1, d2 + h*S_v]
|
|
|
+ int64_t computed_state_idx = d1 * (S_v * H_v) + (d2 + h * S_v);
|
|
|
+ last_state[computed_state_idx] = state_data[ggml_state_idx];
|
|
|
+ }
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
+ // Maintain running cumulative sum across all chunks
|
|
|
+ std::vector<float> running_cumsum(n_tokens, 0.0f);
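|
|
|
+    // running_cumsum[p] accumulates the per-chunk cumsums into a cumulative sum over the whole sequence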
|
|
|
+
|
|
|
+ // Process each chunk
|
|
|
+ for (int64_t chunk_idx = 0; chunk_idx < n_chunks; chunk_idx++) {
|
|
|
+ // Process each head and sequence
|
|
|
+ for (int64_t h = 0; h < H_k; h++) {
|
|
|
+ for (int64_t seq = 0; seq < n_seqs; seq++) {
|
|
|
+ // Extract chunk data for this head and sequence
|
|
|
+ std::vector<float> q_chunk(chunk_size * S_k);
|
|
|
+ std::vector<float> k_chunk(chunk_size * S_k);
|
|
|
+ std::vector<float> v_chunk(chunk_size * S_v);
|
|
|
+ std::vector<float> v_beta_chunk(chunk_size * S_v);
|
|
|
+ std::vector<float> k_beta_chunk(chunk_size * S_k);
|
|
|
+ std::vector<float> g_chunk(chunk_size);
|
|
|
+
|
|
|
+ // Initialize chunks with zeros for padding
|
|
|
+ std::fill(q_chunk.begin(), q_chunk.end(), 0.0f);
|
|
|
+ std::fill(k_chunk.begin(), k_chunk.end(), 0.0f);
|
|
|
+ std::fill(v_chunk.begin(), v_chunk.end(), 0.0f);
|
|
|
+ std::fill(v_beta_chunk.begin(), v_beta_chunk.end(), 0.0f);
|
|
|
+ std::fill(k_beta_chunk.begin(), k_beta_chunk.end(), 0.0f);
|
|
|
+ std::fill(g_chunk.begin(), g_chunk.end(), 0.0f);
|
|
|
+
|
|
|
+ // Determine actual tokens in this chunk
|
|
|
+ int64_t tokens_in_chunk = std::min(chunk_size, n_tokens - chunk_idx * chunk_size);
|
|
|
+
|
|
|
+ // Copy data for this chunk
|
|
|
+ for (int64_t t = 0; t < tokens_in_chunk; t++) {
|
|
|
+ int64_t actual_pos = chunk_idx * chunk_size + t; // Position in the original sequence
|
|
|
+
|
|
|
+ // Only copy if this position is within the original sequence length
|
|
|
+ if (actual_pos < n_tokens) {
|
|
|
+ // Calculate indices in GGML format [chunk_size, head_dim, num_heads, n_seqs]
|
|
|
+ for (int64_t d = 0; d < S_k; d++) {
|
|
|
+ q_chunk[t * S_k + d] = ggml_get_f32_nd(src0, actual_pos, d, h, seq);
|
|
|
+ k_chunk[t * S_k + d] = ggml_get_f32_nd(src1, actual_pos, d, h, seq);
|
|
|
+ k_beta_chunk[t * S_k + d] = ggml_get_f32_nd(src7, actual_pos, d, h, seq);
|
|
|
+ }
|
|
|
+
|
|
|
+ for (int64_t d = 0; d < S_v; d++) {
|
|
|
+ v_chunk[t * S_v + d] = ggml_get_f32_nd(src2, actual_pos, d, h, seq);
|
|
|
+ v_beta_chunk[t * S_v + d] = ggml_get_f32_nd(src6, actual_pos, d, h, seq);
|
|
|
+ }
|
|
|
+
|
|
|
+                    // actual_pos < n_tokens is already guaranteed by the enclosing check above
|
|
|
+                    g_chunk[t] = ggml_get_f32_nd(src3, actual_pos, 0, h, seq);
|
|
|
+ } else {
|
|
|
+ // For padded positions beyond original sequence, set to 0
|
|
|
+ for (int64_t d = 0; d < S_k; d++) {
|
|
|
+ q_chunk[t * S_k + d] = 0.0f;
|
|
|
+ k_chunk[t * S_k + d] = 0.0f;
|
|
|
+ k_beta_chunk[t * S_k + d] = 0.0f;
|
|
|
+ }
|
|
|
+ for (int64_t d = 0; d < S_v; d++) {
|
|
|
+ v_chunk[t * S_v + d] = 0.0f;
|
|
|
+ v_beta_chunk[t * S_v + d] = 0.0f;
|
|
|
+ }
|
|
|
+ g_chunk[t] = 0.0f;
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
+                // In the Python reference, cumsum is applied to each chunk separately after reshaping;
|
|
|
+                // src3 already holds those per-chunk cumsum values, so g_chunk can be used directly
|
|
|
+ for (int64_t i = 0; i < tokens_in_chunk; i++) { // Only for actual tokens, not full chunk size
|
|
|
+ g_cumsum[i] = g_chunk[i];
|
|
|
+ }
|
|
|
+
|
|
|
+ // For padded positions, set cumsum values to 0
|
|
|
+ for (int64_t i = tokens_in_chunk; i < chunk_size; i++) {
|
|
|
+ g_cumsum[i] = 0.0f;
|
|
|
+ }
|
|
|
+
|
|
|
+ // Compute g_exp from cumulative sums (like Python: g.cumsum().exp())
|
|
|
+ // Apply numerical stability to prevent underflow for very negative values
|
|
|
+ for (int64_t i = 0; i < tokens_in_chunk; i++) { // Only for actual tokens, not full chunk size
|
|
|
+                    // Use double precision for the exponential to reduce overflow/underflow (no explicit clamping here)
|
|
|
+ double g_val = (double) g_cumsum[i];
|
|
|
+ double g_exp_double = exp(g_val);
|
|
|
+ g_exp[i] = g_exp_double;
|
|
|
+ }
|
|
|
+
|
|
|
+ // For padded positions, set exp values to 0
|
|
|
+ for (int64_t i = tokens_in_chunk; i < chunk_size; i++) {
|
|
|
+ g_exp[i] = 0.0f;
|
|
|
+ }
|
|
|
+ // Step 1: Compute k_beta @ key.T (this corresponds to the Python: k_beta @ key.transpose(-1, -2))
|
|
|
+ // Only compute for actual tokens in chunk
|
|
|
+ ggml_compute_k_beta_key_t_f32(k_beta_chunk.data(), k_chunk.data(), attn.data(), tokens_in_chunk,
|
|
|
+ S_k); // Use actual tokens, not full chunk size
|
|
|
+
|
|
|
+                // Apply the precomputed decay mask (src5, built in ggml_delta_net in ggml.c) and negate the
|
|
|
+                // result, like Python: -((k_beta @ key.T) * decay_mask); each mask entry is the decay
|
|
|
+                // factor between positions i and j within the chunk
|
|
|
+ for (int64_t i = 0; i < tokens_in_chunk; i++) { // Only for actual tokens
|
|
|
+ for (int64_t j = 0; j < tokens_in_chunk; j++) { // Only for actual tokens
|
|
|
+ // Get decay mask value from precomputed tensor
|
|
|
+ // src5 decay_mask has shape [chunk_size, chunk_size, H_k, n_seqs] in GGML format
|
|
|
+ // Format: [i_pos, j_pos, head, seq] - represents exp(g_cumsum[j] - g_cumsum[i])
|
|
|
+ float decay_val = ggml_get_f32_nd(
|
|
|
+ src5, i, j, h, seq); // [i, j, h, seq] to get exp(g_cumsum[j] - g_cumsum[i]) for head h
|
|
|
+                    if (j < i) { // Only keep the strictly lower triangular part; the diagonal is zeroed like the rest of the mask
|
|
|
+ // The decay_val already contains exp(g_cumsum[j] - g_cumsum[i]), no need for additional exponential
|
|
|
+ // Apply the decay mask and negate (like Python: -((k_beta @ key.T) * decay_mask))
|
|
|
+                        attn[i * tokens_in_chunk + j] = -attn[i * tokens_in_chunk + j] * decay_val;
|
|
|
+ } else {
|
|
|
+                        attn[i * tokens_in_chunk + j] = 0.0f; // Zero out the masked part (like Python: masked_fill(mask, 0))
|
|
|
+ }
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
+ // Step 2: Apply triangular updates (equivalent to Python's complex triangular update)
|
|
|
+ // Python: for i in range(1, chunk_size):
|
|
|
+ // row = attn[..., i, :i].clone() // row = attn[i, 0:i]
|
|
|
+ // sub = attn[..., :i, :i].clone() // sub = attn[0:i, 0:i]
|
|
|
+ // attn[..., i, :i] = row + (row.unsqueeze(-1) * sub).sum(-2)
|
|
|
+ // This means: new_attn[i, j] = old_attn[i, j] + sum_k(old_attn[i, k] * old_attn[k, j]) for k < i
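|
|
|
+                // (forward substitution: rows are processed in increasing i, so attn[k][j] for k < i already holds its updated value when row i is computed)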
|
|
|
+ for (int64_t i = 1; i < tokens_in_chunk; i++) {
|
|
|
+ // Store the original row values to avoid using updated values in computation
|
|
|
+ std::vector<float> original_row(i);
|
|
|
+ for (int64_t j = 0; j < i; j++) {
|
|
|
+ original_row[j] = attn[i * tokens_in_chunk + j]; // Use tokens_in_chunk for indexing
|
|
|
+ }
|
|
|
+
|
|
|
+ for (int64_t j = 0; j < i; j++) {
|
|
|
+ float sum = 0.0f;
|
|
|
+ for (int64_t k = 0; k < i; k++) {
|
|
|
+ // This implements: sum over k of (original_row[k] * sub[k, j])
|
|
|
+ // Where sub[k, j] is attn[k, j] (the original value before updates)
|
|
|
+ sum += original_row[k] * attn[k * tokens_in_chunk + j]; // Use tokens_in_chunk for indexing
|
|
|
+ }
|
|
|
+ // The new value is: original_value + matrix_mult_result
|
|
|
+ attn[i * tokens_in_chunk + j] = original_row[j] + sum;
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
+ // Step 3: Add identity matrix (equivalent to Python's: attn = attn + torch.eye(...))
|
|
|
+ for (int64_t i = 0; i < tokens_in_chunk; i++) {
|
|
|
+ attn[i * tokens_in_chunk + i] += 1.0f;
|
|
|
+ }
|
|
|
+
|
|
|
+ // Step 4: Compute value = attn @ v_beta
|
|
|
+ ggml_compute_value_f32(attn.data(), v_beta_chunk.data(), value.data(), tokens_in_chunk,
|
|
|
+ S_v); // Use actual tokens, not full chunk size
|
|
|
+
|
|
|
+ // Step 5: Compute k_cumdecay = attn @ (k_beta * g_exp)
|
|
|
+ ggml_compute_k_cumdecay_f32(attn.data(), k_beta_chunk.data(), g_exp.data(), k_cumdecay.data(),
|
|
|
+ tokens_in_chunk, S_k); // Use actual tokens, not full chunk size
|
|
|
|
|
|
+ // Step 6: Compute core attention output for this chunk
|
|
|
+ // First, compute v_new for all tokens in the chunk
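|
|
|
+                // v_new[i] = value[i] - k_cumdecay[i] @ last_state, i.e. the residual not yet captured by the running state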
|
|
|
+ std::vector<float> v_new_chunk(tokens_in_chunk * S_v);
|
|
|
+ for (int64_t i = 0; i < tokens_in_chunk; i++) {
|
|
|
+ // v_prime = k_cumdecay @ last_state
|
|
|
+ // k_cumdecay[i] is [S_k], last_state for head h is [S_k, S_v]
|
|
|
+ std::vector<float> v_prime(S_v, 0.0f);
|
|
|
+ for (int64_t d1 = 0; d1 < S_v; d1++) {
|
|
|
+ for (int64_t d2 = 0; d2 < S_k; d2++) {
|
|
|
+ // State index: [d2, d1 + h*S_v] in GGML format
|
|
|
+ int64_t state_idx = d2 * (S_v * H_v) + (d1 + h * S_v);
|
|
|
+ v_prime[d1] += k_cumdecay[i * S_k + d2] * last_state[state_idx];
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
+ // v_new = v_i - v_prime
|
|
|
+ for (int64_t d = 0; d < S_v; d++) {
|
|
|
+ v_new_chunk[i * S_v + d] = value[i * S_v + d] - v_prime[d];
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
+ // Now process each token in the chunk to compute output
|
|
|
+ for (int64_t i = 0; i < tokens_in_chunk; i++) {
|
|
|
+ // q_i @ k_i.T * decay_mask
|
|
|
+ std::vector<float> q_k_attn(chunk_size);
|
|
|
+ for (int64_t j = 0; j < chunk_size; j++) {
|
|
|
+ float sum = 0.0f;
|
|
|
+ for (int64_t d = 0; d < S_k; d++) {
|
|
|
+ sum += q_chunk[i * S_k + d] * k_chunk[j * S_k + d];
|
|
|
+ }
|
|
|
+ // Apply decay mask - use the precomputed decay mask from src5 tensor
|
|
|
+ if (j <= i) { // Only apply to lower triangular part (i >= j)
|
|
|
+ float decay_val = ggml_get_f32_nd(
|
|
|
+ src5, i, j, h, seq); // [i, j, h, seq] to get exp(g_cumsum[i] - g_cumsum[j]) for head h
|
|
|
+ q_k_attn[j] = sum * decay_val;
|
|
|
+ } else {
|
|
|
+ q_k_attn[j] = 0.0f; // Zero out upper triangular part (i < j)
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
+ // attn_inter = q_i * g_exp @ last_state
|
|
|
+ // q_chunk[i] is [S_k], g_exp[i] is scalar, last_state for head h is [S_k, S_v]
|
|
|
+ std::vector<float> attn_inter(S_v, 0.0f);
|
|
|
+ for (int64_t d1 = 0; d1 < S_v; d1++) {
|
|
|
+ for (int64_t d2 = 0; d2 < S_k; d2++) {
|
|
|
+ // State index: [d2, d1 + h*S_v] in GGML format
|
|
|
+ int64_t state_idx = d2 * (S_v * H_v) + (d1 + h * S_v);
|
|
|
+ // Use double precision for the computation and then cast to float
|
|
|
+ double temp_result =
|
|
|
+ (double) q_chunk[i * S_k + d2] * g_exp[i] * (double) last_state[state_idx];
|
|
|
+ attn_inter[d1] += (float) temp_result;
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
+ // core_attn_out = attn_inter + attn @ v_new
|
|
|
+                    // Here the reference's "attn" is q_k_attn, the decay-masked q_i @ k^T row computed above;
|
|
|
+                    // for token i we want sum_j(q_k_attn[j] * v_new[j, :]) with v_new of shape [tokens_in_chunk, S_v]
|
|
|
+ std::vector<float> attn_v_new(S_v, 0.0f);
|
|
|
+ for (int64_t d = 0; d < S_v; d++) {
|
|
|
+ for (int64_t j = 0; j < tokens_in_chunk; j++) { // Only process actual tokens
|
|
|
+                            // q_k_attn[j] is the masked attention weight from position i to position j
|
|
|
+                            // v_new_chunk[j * S_v + d] is the v_new value for token j, dimension d
|
|
|
+                            attn_v_new[d] += q_k_attn[j] * v_new_chunk[j * S_v + d];
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
+ // Store output - only store for the original sequence length (not the padded part)
|
|
|
+ int64_t global_pos =
|
|
|
+ chunk_idx * chunk_size + i; // Convert local chunk position to global sequence position
|
|
|
+ if (global_pos < n_tokens) { // Make sure we don't exceed the original sequence length
|
|
|
+ for (int64_t d = 0; d < S_v; d++) {
|
|
|
+ // Output tensor is [S_v * H_v * n_tokens] for single sequence (n_seqs=1)
|
|
|
+ // Indexing: [dim_idx + head_idx*S_v + pos_idx*S_v*H_v]
|
|
|
+ int64_t ggml_idx = d + h * S_v + global_pos * S_v * H_v;
|
|
|
+ output[ggml_idx] = attn_inter[d] + attn_v_new[d];
|
|
|
+ }
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
+ // Step 7: Update last_recurrent_state
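|
|
|
+                // state update (see the steps below): S_new = S * exp(g_last) + (k * exp(g_last - g))^T @ v_new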
|
|
|
+ std::vector<float> new_state_vec(S_v * S_v * H_v);
|
|
|
+
|
|
|
+ // Update running cumulative sum with current chunk's values
|
|
|
+ float prev_cumsum = 0.0f; // Cumulative sum from all previous chunks
|
|
|
+ if (chunk_idx > 0) {
|
|
|
+ // Get the cumulative sum of the last token from the previous chunk
|
|
|
+ int64_t prev_chunk_last_token = std::min(chunk_size, n_tokens - (chunk_idx - 1) * chunk_size) - 1;
|
|
|
+ if (prev_chunk_last_token >= 0) {
|
|
|
+ prev_cumsum = running_cumsum[(chunk_idx - 1) * chunk_size + prev_chunk_last_token];
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
+ // Update running_cumsum for tokens in this chunk
|
|
|
+ for (int64_t t = 0; t < tokens_in_chunk; t++) {
|
|
|
+ int64_t global_pos = chunk_idx * chunk_size + t;
|
|
|
+ if (global_pos < n_tokens) {
|
|
|
+ running_cumsum[global_pos] = prev_cumsum + g_cumsum[t];
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
+ // Find the last token position in the current chunk (not the entire sequence)
|
|
|
+ int64_t last_pos_in_chunk =
|
|
|
+ std::min((chunk_idx + 1) * chunk_size, n_tokens) - 1; // Last actual token in this chunk
|
|
|
+ if (last_pos_in_chunk >= chunk_idx * chunk_size && last_pos_in_chunk < n_tokens) {
|
|
|
+ float g_last =
|
|
|
+ running_cumsum[last_pos_in_chunk]; // Use the last token's cumulative sum in this chunk
|
|
|
+ // Use double precision for exponential to avoid overflow/underflow
|
|
|
+ double g_last_exp_double = exp((double) g_last);
|
|
|
+ float g_last_exp = (float) g_last_exp_double;
|
|
|
+
|
|
|
+ // last_state * g_exp[last]
|
|
|
+ for (int64_t i = 0; i < S_k; i++) {
|
|
|
+ for (int64_t j = 0; j < S_v; j++) {
|
|
|
+ // State index: [i, j + h*S_v] in GGML format
|
|
|
+ int64_t state_idx = i * (S_v * H_v) + (j + h * S_v);
|
|
|
+ new_state_vec[i * (S_v * H_v) + (j + h * S_v)] = last_state[state_idx] * g_last_exp;
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
+ // Add (k_i * (g_last - g_i).exp()).T @ v_new
|
|
|
+ // This should be: (k_chunk * g_diff_exp).T @ v_new_chunk
|
|
|
+ // where k_chunk is [chunk_size, S_k], v_new_chunk is [chunk_size, S_v]
|
|
|
+ // result is [S_k, S_v]
|
|
|
+
|
|
|
+ // First compute v_new for all positions in the chunk
|
|
|
+ std::vector<float> v_new_chunk(chunk_size * S_v);
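|
|
|
+                    // recomputes the same residual as in Step 6 (this declaration shadows the outer v_new_chunk)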
|
|
|
+ for (int64_t i = 0; i < tokens_in_chunk; i++) { // Only process actual tokens, not full chunk size
|
|
|
+ for (int64_t d1 = 0; d1 < S_v; d1++) {
|
|
|
+ // Recompute v_prime for this position
|
|
|
+ float v_prime = 0.0f;
|
|
|
+ for (int64_t d2 = 0; d2 < S_k; d2++) {
|
|
|
+ // State index: [d2, d1 + h*S_v] in GGML format
|
|
|
+ int64_t state_idx = d2 * (S_v * H_v) + (d1 + h * S_v);
|
|
|
+ float k_val = k_cumdecay[i * S_k + d2];
|
|
|
+ float s_val = last_state[state_idx];
|
|
|
+ v_prime += k_val * s_val;
|
|
|
+ }
|
|
|
+ v_new_chunk[i * S_v + d1] = value[i * S_v + d1] - v_prime;
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
+ // Now compute (k_chunk * g_diff_exp).T @ v_new_chunk
|
|
|
+ // This is a matrix multiplication: [S_k, chunk_size] @ [chunk_size, S_v] = [S_k, S_v]
|
|
|
+ // Only process the original sequence length, not the padded chunk size
|
|
|
+ // In the Python reference, this is: (k_i * g_diff_exp).transpose(-1, -2) @ v_new
|
|
|
+ // where g_diff_exp = torch.exp(g_last - g) and g_last = g[-1] (last token in chunk)
|
|
|
+ for (int64_t d1 = 0; d1 < S_k; d1++) {
|
|
|
+ for (int64_t d2 = 0; d2 < S_v; d2++) {
|
|
|
+ float sum = 0.0f;
|
|
|
+ for (int64_t i = 0; i < tokens_in_chunk; i++) { // Only process actual tokens
|
|
|
+                            // g values come from the per-chunk cumsum tensor (src3), already copied into g_cumsum above;
|
|
|
+                            // the state update needs exp(g_last - g_current), where g_last is the cumsum at the last
|
|
|
+                            // real token of this chunk and g_current is the cumsum at token i
|
|
|
+                            float g_current = g_cumsum[i];
|
|
|
+ float g_last =
|
|
|
+ g_cumsum[tokens_in_chunk - 1]; // Use the last token's cumsum in this chunk
|
|
|
+
|
|
|
+ float g_diff = g_last - g_current;
|
|
|
+ float g_diff_exp;
|
|
|
+ // Use double precision for exponential to avoid overflow/underflow
|
|
|
+ // For numerical stability, if g_diff is very negative, exp(g_diff) will be very small
|
|
|
+ if (g_diff < -50.0f) {
|
|
|
+ g_diff_exp = 0.0f; // Set to zero to avoid underflow
|
|
|
+ } else {
|
|
|
+ double g_diff_exp_double = exp((double) g_diff);
|
|
|
+ g_diff_exp = (float) g_diff_exp_double;
|
|
|
+ }
|
|
|
+ sum += k_chunk[i * S_k + d1] * g_diff_exp * v_new_chunk[i * S_v + d2];
|
|
|
+ }
|
|
|
+ // State index: [d1, d2 + h*S_v] in GGML format
|
|
|
+ int64_t state_idx = d1 * (S_v * H_v) + (d2 + h * S_v);
|
|
|
+ new_state_vec[state_idx] += sum;
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
+ // Update last_state
|
|
|
+ for (int64_t i = 0; i < S_k; i++) {
|
|
|
+ for (int64_t j = 0; j < S_v; j++) {
|
|
|
+ // State index: [i, j + h*S_v] in GGML format
|
|
|
+ int64_t state_idx = i * (S_v * H_v) + (j + h * S_v);
|
|
|
+ last_state[state_idx] = new_state_vec[state_idx];
|
|
|
+ }
|
|
|
+ }
|
|
|
+ }
|
|
|
+ }
|
|
|
+ }
|
|
|
+ }
|
|
|
+ // Copy the final state to the output tensor in the correct GGML layout
|
|
|
+ // GGML expects state layout: [d1, d2 + h*head_dim]
|
|
|
+ for (int64_t h = 0; h < H_v; h++) {
|
|
|
+ for (int64_t d1 = 0; d1 < S_v; d1++) {
|
|
|
+ for (int64_t d2 = 0; d2 < S_v; d2++) {
|
|
|
+ // GGML state index: [d1, d2 + h*head_dim]
|
|
|
+ int64_t ggml_state_idx = d1 * (S_v * H_v) + (d2 + h * S_v);
|
|
|
+ // Our computed state index: [d1, d2 + h*S_v]
|
|
|
+ int64_t computed_state_idx = d1 * (S_v * H_v) + (d2 + h * S_v);
|
|
|
+ float val = last_state[computed_state_idx];
|
|
|
+ new_state[ggml_state_idx] = val;
|
|
|
+ }
|
|
|
+ }
|
|
|
+ }
|
|
|
+}
|
|
|
+
|
|
|
+// ggml_compute_forward_rwkv_wkv7
|
|
|
static void ggml_compute_forward_rwkv_wkv7_f32(
|
|
|
const ggml_compute_params * params,
|
|
|
ggml_tensor * dst) {
|