@@ -10529,7 +10529,69 @@ static void delta_compute_k_beta_key_t_f32(const float * k_beta, const float * k
     }
 }
 
-// Helper function to apply triangular updates
+// Helper function to apply triangular updates to entire chunk (all sequences and heads)
+static void delta_apply_triangular_updates_chunk_f32(float * attn, const int64_t chunk_size,
+                                                     const int64_t n_seqs, const int64_t H_v) {
+    for (int64_t seq = 0; seq < n_seqs; seq++) {
+        for (int64_t head = 0; head < H_v; head++) {
+            float * attn_ptr = attn + seq * (chunk_size * chunk_size * H_v) + head * (chunk_size * chunk_size);
+
+            // Apply triangular updates following the Python reference exactly:
+            //   for i in range(1, chunk_size):
+            //       row = attn[..., i, :i].clone()
+            //       sub = attn[..., :i, :i].clone()
+            //       attn[..., i, :i] = row + (row.unsqueeze(-1) * sub).sum(-2)
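+            // (Together with the identity added below, this forward substitution builds the
+            //  rows of (I - A)^{-1}, where A is the strictly lower-triangular input matrix,
+            //  i.e. the inverse used by the chunked delta rule.)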
+            for (int64_t i = 1; i < chunk_size; i++) {
+                // Create temporary storage for row and sub to avoid modifying during computation
+                float * row = (float *) malloc(i * sizeof(float));
+                float * sub = (float *) malloc(i * i * sizeof(float));
+
+                // Copy row = attn[..., i, :i]
+                for (int64_t j = 0; j < i; j++) {
+                    row[j] = attn_ptr[i * chunk_size + j];
+                }
+
+                // Copy sub = attn[..., :i, :i]
+                for (int64_t k = 0; k < i; k++) {
+                    for (int64_t j = 0; j < i; j++) {
+                        sub[k * i + j] = attn_ptr[k * chunk_size + j];
+                    }
+                }
+
+                // Compute updates for each j in :i
+                for (int64_t j = 0; j < i; j++) {
+                    // Compute (row.unsqueeze(-1) * sub).sum(-2)
+                    float sum_val = 0.0f;
+                    for (int64_t k = 0; k < i; k++) {
+                        sum_val += row[k] * sub[k * i + j];
+                    }
+
+                    // Update: attn[..., i, j] = row[j] + sum_val
+                    attn_ptr[i * chunk_size + j] = row[j] + sum_val;
+                }
+
+                free(row);
+                free(sub);
+            }
+        }
+    }
+}
+
+// Helper function to add identity matrix to entire chunk (all sequences and heads)
+static void delta_add_identity_matrix_chunk_f32(float * matrix, const int64_t chunk_size,
+                                                const int64_t n_seqs, const int64_t H_v) {
+    for (int64_t seq = 0; seq < n_seqs; seq++) {
+        for (int64_t head = 0; head < H_v; head++) {
+            float * matrix_ptr = matrix + seq * (chunk_size * chunk_size * H_v) + head * (chunk_size * chunk_size);
+            // Add identity matrix directly
+            for (int64_t i = 0; i < chunk_size; i++) {
+                matrix_ptr[i * chunk_size + i] += 1.0f;
+            }
+        }
+    }
+}
+
+// Helper function to apply triangular updates (original version for individual matrices)
 static void delta_apply_triangular_updates_f32(float * attn, const int64_t chunk_size) {
     for (int64_t i = 1; i < chunk_size; i++) {
         for (int64_t j = 0; j < i; j++) {
@@ -10542,40 +10604,41 @@ static void delta_apply_triangular_updates_f32(float * attn, const int64_t chunk
     }
 }
 
-// Helper function to add identity matrix
+// Helper function to add identity matrix (original version for individual matrices)
 static void delta_add_identity_matrix_f32(float * matrix, const int64_t size) {
     for (int64_t i = 0; i < size; i++) {
         matrix[i * size + i] += 1.0f;
     }
 }
 
-// Helper function to compute value = attn @ v_beta
-static void delta_compute_value_f32(const float * attn, const float * v_beta,
-                                    float * value,
-                                    const int64_t chunk_size, const int64_t v_head_dim) {
-    for (int64_t i = 0; i < chunk_size; i++) {
-        for (int64_t d = 0; d < v_head_dim; d++) {
-            float sum = 0.0f;
-            for (int64_t j = 0; j < chunk_size; j++) {
-                int64_t v_beta_idx = j * v_head_dim + d;
-                sum += attn[i * chunk_size + j] * v_beta[v_beta_idx];
-            }
-            value[i * v_head_dim + d] = sum;
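+// Helper function to compute value = attn @ v_beta for all sequences and heads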
+static void delta_compute_value_f32(const float * attn,
+                                    const float * v_beta,
+                                    float * value,
+                                    const int64_t chunk_size,
+                                    const int64_t v_head_dim,
+                                    const int64_t n_heads,
+                                    const int64_t n_seqs) {
+    for (int64_t seq = 0; seq < n_seqs; seq++) {
+        for (int64_t head = 0; head < n_heads; head++) {
+            delta_matmul_f32(
+                attn   + (chunk_size * chunk_size * n_heads) * seq + (chunk_size * chunk_size) * head,
+                v_beta + (chunk_size * v_head_dim * n_heads) * seq + (chunk_size * v_head_dim) * head,
+                value  + (chunk_size * v_head_dim * n_heads) * seq + (chunk_size * v_head_dim) * head,
+                chunk_size, v_head_dim, chunk_size);
         }
     }
 }
 
-// Helper function to compute k_cumdecay = attn @ (k_beta * g.exp().unsqueeze(-1))
+// Helper function to compute k_cumdecay = attn @ (k_beta * g.exp().unsqueeze(-1)) for single head/sequence
 static void delta_compute_k_cumdecay_f32(const float * attn, const float * k_beta, const float * g,
-                                        float * k_cumdecay, const int64_t chunk_size, const int64_t k_head_dim) {
+                                         float * k_cumdecay, const int64_t chunk_size, const int64_t k_head_dim) {
     for (int64_t i = 0; i < chunk_size; i++) {
-        for (int64_t d = 0; d < k_head_dim; d++) {
+        for (int64_t j = 0; j < k_head_dim; j++) {
             float sum = 0.0f;
-            for (int64_t j = 0; j < chunk_size; j++) {
-                int64_t k_beta_idx = j * k_head_dim + d;
-                sum += attn[i * chunk_size + j] * k_beta[k_beta_idx] * expf(g[j]);
+            for (int64_t k = 0; k < chunk_size; k++) {
+                sum += attn[i * chunk_size + k] * k_beta[k * k_head_dim + j] * expf(g[k]);
             }
-            k_cumdecay[i * k_head_dim + d] = sum;
+            k_cumdecay[i * k_head_dim + j] = sum;
         }
     }
 }
@@ -10625,7 +10688,9 @@ static void delta_matmul_state_f32(const float * a, const float * state, float *
             for (int64_t k = 0; k < cols_a; k++) {
                 int64_t a_idx = i * cols_a + k;
                 int64_t state_idx = k * cols_state + j;
-                sum += a[a_idx] * state[state_idx];
+                float a_val = a[a_idx];
+                float state_val = state[state_idx];
+                sum += a_val * state_val;
             }
             dst[i * cols_state + j] = sum;
         }
@@ -10670,6 +10735,108 @@ static void delta_update_recurrent_state_f32(const float * last_state, const flo
     }
 }
 
+// Helper function to compute q_i @ k_i.transpose(-1, -2) * decay_mask and apply mask for entire chunk
+static void delta_compute_q_k_attn_chunk_f32(const float * q, const float * k, const float * decay_mask,
+                                             float * attn, const bool * mask,
+                                             const int64_t chunk_size, const int64_t head_dim,
+                                             const int64_t n_seqs, const int64_t H_v) {
+    for (int64_t seq = 0; seq < n_seqs; seq++) {
+        for (int64_t head = 0; head < H_v; head++) {
+            const float * q_ptr = q + seq * (chunk_size * head_dim * H_v) + head * (chunk_size * head_dim);
+            const float * k_ptr = k + seq * (chunk_size * head_dim * H_v) + head * (chunk_size * head_dim);
+            const float * decay_mask_ptr = decay_mask + seq * (chunk_size * chunk_size * H_v) + head * (chunk_size * chunk_size);
+            float * attn_ptr = attn + seq * (chunk_size * chunk_size * H_v) + head * (chunk_size * chunk_size);
+            delta_compute_q_k_attn_f32(q_ptr, k_ptr, decay_mask_ptr, attn_ptr, mask, chunk_size, head_dim);
+        }
+    }
+}
+
+// Helper function for matrix multiplication with state tensors for entire chunk
+static void delta_matmul_state_chunk_f32(const float * a, const float * state, float * dst,
+                                         const int64_t rows_a, const int64_t cols_a, const int64_t cols_state,
+                                         const int64_t n_seqs, const int64_t H_v) {
+    for (int64_t seq = 0; seq < n_seqs; seq++) {
+        for (int64_t head = 0; head < H_v; head++) {
+            const float * a_ptr = a + seq * (rows_a * cols_a * H_v) + head * (rows_a * cols_a);
+            const float * state_ptr = state + seq * (cols_a * cols_state * H_v) + head * (cols_a * cols_state);
+            float * dst_ptr = dst + seq * (rows_a * cols_state * H_v) + head * (rows_a * cols_state);
+            delta_matmul_state_f32(a_ptr, state_ptr, dst_ptr, rows_a, cols_a, cols_state);
+        }
+    }
+}
+
+// Helper function to update recurrent state for entire chunk
+static void delta_update_recurrent_state_chunk_f32(const float * state, const float * g_last,
+        const float * k, const float * g_diff_exp, const float * v_new, float * new_state,
+        const int64_t chunk_size, const int64_t k_head_dim, const int64_t v_head_dim,
+        const int64_t n_seqs, const int64_t H_v) {
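+    // Per (seq, head) this computes:
+    //     new_state = state * g_last + (k * g_diff_exp[:, None])^T @ v_new
+    // where the caller passes g_last = exp(g[-1]) and g_diff_exp = exp(g[-1] - g),
+    // matching the reference recurrent-state update.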
+    for (int64_t seq = 0; seq < n_seqs; seq++) {
+        for (int64_t head = 0; head < H_v; head++) {
+            const float * state_ptr = state + seq * (k_head_dim * v_head_dim * H_v) + head * (k_head_dim * v_head_dim);
+            const float * k_ptr = k + seq * (chunk_size * k_head_dim * H_v) + head * (chunk_size * k_head_dim);
+            const float * g_diff_exp_ptr = g_diff_exp + seq * (chunk_size * H_v) + head * chunk_size;
+            const float * v_new_ptr = v_new + seq * (chunk_size * v_head_dim * H_v) + head * (chunk_size * v_head_dim);
+            float * new_state_ptr = new_state + seq * (k_head_dim * v_head_dim * H_v) + head * (k_head_dim * v_head_dim);
+
+            for (int64_t i = 0; i < k_head_dim; i++) {
+                for (int64_t j = 0; j < v_head_dim; j++) {
+                    int64_t state_idx = i * v_head_dim + j;
+
+                    // last_recurrent_state * g_last
+                    float term1 = state_ptr[state_idx] * g_last[seq * H_v + head];
+
+                    // (k_i * g_diff_exp).transpose(-1, -2) @ v_new
+                    float term2 = 0.0f;
+                    for (int64_t k = 0; k < chunk_size; k++) {
+                        int64_t k_idx = k * k_head_dim + i;
+                        int64_t v_idx = k * v_head_dim + j;
+                        term2 += k_ptr[k_idx] * g_diff_exp_ptr[k] * v_new_ptr[v_idx];
+                    }
+
+                    new_state_ptr[state_idx] = term1 + term2;
+                }
+            }
+        }
+    }
+}
+
+// Helper function for element-wise tensor subtraction for entire chunk
+static void delta_tensor_subtract_chunk_f32(const float * a, const float * b, float * dst, const int64_t size,
+                                            const int64_t n_seqs, const int64_t H_v) {
+    for (int64_t seq = 0; seq < n_seqs; seq++) {
+        for (int64_t head = 0; head < H_v; head++) {
+            const float * a_ptr = a + seq * (size * H_v) + head * size;
+            const float * b_ptr = b + seq * (size * H_v) + head * size;
+            float * dst_ptr = dst + seq * (size * H_v) + head * size;
+            delta_tensor_subtract_f32(a_ptr, b_ptr, dst_ptr, size);
+        }
+    }
+}
+
+// Helper function for element-wise tensor addition for entire chunk
+static void delta_tensor_add_chunk_f32(const float * a, const float * b, float * dst, const int64_t size,
+                                       const int64_t n_seqs, const int64_t H_v) {
+    for (int64_t seq = 0; seq < n_seqs; seq++) {
+        for (int64_t head = 0; head < H_v; head++) {
+            const float * a_ptr = a + seq * (size * H_v) + head * size;
+            const float * b_ptr = b + seq * (size * H_v) + head * size;
+            float * dst_ptr = dst + seq * (size * H_v) + head * size;
+            delta_tensor_add_f32(a_ptr, b_ptr, dst_ptr, size);
+        }
+    }
+}
+
+
+static void print_debug_info(float * data, size_t size, const char * name, int64_t token) {
+    GGML_LOG_INFO("\nggml-debug: %s (%lld) first 5 values: [%.6f, %.6f, %.6f, %.6f, %.6f, ...]\n",
+                  name, (long long) token, data[0], data[1], data[2], data[3], data[4]);
+    double sum = 0.0;
+    for (size_t i = 0; i < size; i++) {
+        sum += data[i];
+    }
+    GGML_LOG_INFO("total elements: %zu, sum = %.10f\n", size, sum);
+}
+
 void ggml_compute_forward_delta_net_f32(const ggml_compute_params * params, ggml_tensor * dst) {
     const struct ggml_tensor * src0 = dst->src[0]; // q (already normalized and scaled)
     const struct ggml_tensor * src1 = dst->src[1]; // k (already normalized)
@@ -10682,7 +10849,6 @@ void ggml_compute_forward_delta_net_f32(const ggml_compute_params * params, ggml
     const struct ggml_tensor * src8 = dst->src[8]; // attn
 
     const int64_t H_v = (int64_t) dst->op_params[0];
-    const int64_t S_k = (int64_t) dst->op_params[1];
     const int64_t S_v = (int64_t) dst->op_params[2];
     const int64_t original_n_tokens = (int64_t) dst->op_params[3]; // Get original sequence length
     const int64_t n_tokens = original_n_tokens; // Use the original sequence length
@@ -10698,15 +10864,17 @@ void ggml_compute_forward_delta_net_f32(const ggml_compute_params * params, ggml
 
     float * dst_data = (float *) dst->data;
     // Following GLA pattern: output is first part, state is second part
-    float * output = dst_data; // [S_v * H_v, n_tokens, 1, 1] - only real sequence length, not padded
-    float * new_state = dst_data + (S_v * H_v * n_tokens); // [S_v * H_v, S_v * n_seqs, 1, 1]
+    float * output = dst_data; // [S_v * H_v, n_tokens, 1, n_seqs] - only real sequence length, not padded
+    float * new_state = dst_data + (S_v * H_v * n_tokens * n_seqs); // [S_v * H_v, S_v * n_seqs, 1, 1]
 
     const int ith = params->ith;
-    const int nth = params->nth; // nth is unused
+    // const int nth = params->nth; // nth is unused
 
     // Clear output and new state section
     if (ith == 0) {
         memset(output, 0, ((S_v * H_v * n_tokens * n_seqs) + (S_v * S_v * H_v * n_seqs)) * sizeof(float));
+    } else {
+        return;
     }
 
     // Calculate chunk size
@@ -10714,16 +10882,7 @@ void ggml_compute_forward_delta_net_f32(const ggml_compute_params * params, ggml
     const int64_t pad_size = (chunk_size - n_tokens % chunk_size) % chunk_size;
     const int64_t num_chunks = (n_tokens + pad_size) / chunk_size;
 
-    // Apply triangular updates to the precomputed attention matrix
-    float * attn_data = (float *) src8->data;
-    float * v_beta_data = (float *) src6->data;
-    float * k_beta_data = (float *) src7->data;
-    float * g_data = (float *) src3->data;
-    float * q_data = (float *) src0->data;
-    float * k_data = (float *) src1->data;
-    //float * v_data = (float *) src2->data;
     float * state_data = (float *) src4->data;
-    float * decay_mask_data = (float *) src5->data;
 
     GGML_ASSERT(ggml_is_contiguous(src0));
     GGML_ASSERT(ggml_is_contiguous(src1));
@@ -10735,161 +10894,347 @@ void ggml_compute_forward_delta_net_f32(const ggml_compute_params * params, ggml
     GGML_ASSERT(ggml_is_contiguous(src7));
     GGML_ASSERT(ggml_is_contiguous(src8));
 
-    int64_t total_params = n_seqs * H_v * num_chunks;
-    int64_t per_thread = (total_params % nth == 0) ? total_params / nth : (total_params / nth) + 1;
+    // int64_t total_params = n_seqs * H_v * num_chunks;
+    // int64_t per_thread = (total_params % nth == 0) ? total_params / nth : (total_params / nth) + 1;
+
+    // Create helper lambda for state tensor access
+    const auto state_ptr = [state_data, src4] (int64_t seq, int64_t head, int64_t i, int64_t j) {
+        return state_data + (j * src4->nb[0] / sizeof(float)) + (i * src4->nb[1] / sizeof(float)) +
+               (head * src4->nb[2] / sizeof(float)) + (seq * src4->nb[3] / sizeof(float));
+    };
+
+    float * attn       = (float *) malloc(chunk_size * chunk_size * H_v * n_seqs * sizeof(float));
+    float * value      = (float *) malloc(chunk_size * S_v * H_v * n_seqs * sizeof(float));
+    float * k_cumdecay = (float *) malloc(chunk_size * S_v * H_v * n_seqs * sizeof(float));
+    bool  * mask       = (bool  *) malloc(chunk_size * chunk_size * sizeof(bool));
+    float * g          = (float *) malloc(chunk_size * H_v * n_seqs * sizeof(float));
+
+    // Create upper triangular mask for causal attention (exclude diagonal)
+    for (int64_t i = 0; i < chunk_size; i++) {
+        for (int64_t j = 0; j < chunk_size; j++) {
+            mask[i * chunk_size + j] = (j > i); // True for upper triangular (excluding diagonal)
+        }
+    }
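+    // (equivalent to mask = torch.triu(torch.ones(chunk_size, chunk_size, dtype=torch.bool), diagonal=1))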
+
+    // Make a copy of the attention tensor and the gate cumsum tensor.
+    // The local buffers hold one chunk per (seq, head), so clamp the copy size in case
+    // the sources cover more than one chunk (num_chunks > 1).
+    memcpy(attn, src8->data, MIN(ggml_nbytes(src8), (size_t)(chunk_size * chunk_size * H_v * n_seqs) * sizeof(float)));
+    memcpy(g,    src3->data, MIN(ggml_nbytes(src3), (size_t)(chunk_size * H_v * n_seqs) * sizeof(float)));
+
+    // Prepare the initial attention matrix with triangular updates and identity (for entire chunks)
+    // This corresponds to the reference implementation:
+    //   for i in range(1, chunk_size): attn[..., i, :i] = row + (row.unsqueeze(-1) * sub).sum(-2)
+    //   attn = attn + torch.eye(chunk_size)
+    delta_apply_triangular_updates_chunk_f32(attn, chunk_size, n_seqs, H_v);
+    delta_add_identity_matrix_chunk_f32(attn, chunk_size, n_seqs, H_v);
+
+    // Compute value = attn @ v_beta
+    delta_compute_value_f32(attn, (const float *) src6->data, value, chunk_size, S_v, H_v, n_seqs);
     for (int64_t seq = 0; seq < n_seqs; seq++) {
         for (int64_t head = 0; head < H_v; head++) {
-            for (int64_t chunk = 0; chunk < num_chunks; chunk++) {
-                int64_t tidx = seq * (H_v * num_chunks) + head * num_chunks + chunk;
-                if (tidx < ith * per_thread || tidx >= (ith + 1) * per_thread) {
-                    continue; // not our thread;
-                }
-                float * attn_data_for_chs = attn_data + (src8->nb[3] / sizeof(float)) * seq + (src8->nb[2] / sizeof(float)) * (chunk + head * num_chunks);
-                float * value_chunk = (float *) malloc(S_v * chunk_size * H_v * n_seqs * sizeof(float));
-                float * k_cumdecay = (float *) malloc(S_v * chunk_size * H_v * n_seqs * sizeof(float));
-                delta_apply_triangular_updates_f32(attn_data_for_chs, chunk_size);
-                delta_add_identity_matrix_f32(attn_data_for_chs, chunk_size);
-                // Calculate the correct v_beta and k_beta pointers for this head and sequence
-                float * v_beta_chunk = v_beta_data + (src6->nb[3] / sizeof(float)) * seq + (src6->nb[2] / sizeof(float)) * (chunk + head * num_chunks);
-                float * k_beta_chunk = k_beta_data + (src7->nb[3] / sizeof(float)) * seq + (src7->nb[2] / sizeof(float)) * (chunk + head * num_chunks);
-                // The g tensor has dimensions [8, 64, 2, 1] = [features, tokens, heads, sequences]
-                // We need to access the correct head data
-                // For each head, we need to access the correct feature for all tokens in the chunk
-                // Let's try accessing feature index chunk (since we have 8 features and chunk=0)
-                float * g_chunk = g_data + (src3->nb[3] / sizeof(float)) * seq + (src3->nb[2] / sizeof(float)) * head + (src3->nb[1] / sizeof(float)) * (chunk * chunk_size);
-                delta_compute_value_f32(attn_data_for_chs, v_beta_chunk, value_chunk, chunk_size, S_v);
-                delta_compute_k_cumdecay_f32(attn_data_for_chs, k_beta_chunk, g_chunk, k_cumdecay, chunk_size, S_k);
-                // Now compute the per-chunk-specific part (corresponding to the inner loop in Python)
-                float * q_chunk = q_data + (src0->nb[3] / sizeof(float)) * seq + (src0->nb[2] / sizeof(float)) * (chunk + head * num_chunks);
-                float * k_chunk = k_data + (src1->nb[3] / sizeof(float)) * seq + (src1->nb[2] / sizeof(float)) * (chunk + head * num_chunks);
-                float * decay_mask_chunk = decay_mask_data + (src5->nb[3] / sizeof(float)) * seq + (src5->nb[2] / sizeof(float)) * (chunk + head * num_chunks);
-                float * k_cumdecay_chunk = k_cumdecay + (S_v * chunk_size * H_v) * seq + (S_v * chunk_size) * head;
-
-                // Allocate temporary variables for the loop
-                float * attn = (float *) malloc(chunk_size * chunk_size * sizeof(float));
-                float * v_prime = (float *) malloc(chunk_size * S_v * sizeof(float));
-                float * v_new = (float *) malloc(chunk_size * S_v * sizeof(float));
-                float * attn_inter = (float *) malloc(chunk_size * S_v * sizeof(float));
-                float * core_attn_out_chunk = (float *) malloc(chunk_size * S_v * sizeof(float));
-                float * g_last = (float *) malloc(sizeof(float));
-                float * g_diff_exp = (float *) malloc(chunk_size * sizeof(float));
-                bool * mask = (bool *) malloc(chunk_size * chunk_size * sizeof(bool));
-
-                // Create upper triangular mask for causal attention (exclude diagonal)
+            delta_compute_k_cumdecay_f32(attn + (chunk_size * chunk_size * H_v) * seq + (chunk_size * chunk_size) * head,
+                                         (float *) src7->data + (chunk_size * S_v * H_v) * seq + (chunk_size * S_v) * head,
+                                         g + (chunk_size * H_v) * seq + chunk_size * head,
+                                         k_cumdecay + (chunk_size * S_v * H_v) * seq + (chunk_size * S_v) * head,
+                                         chunk_size, S_v);
+        }
+    }
+    print_debug_info(k_cumdecay, chunk_size * S_v * H_v * n_seqs, "k_cumdecay", -1);
+
+    // Process each chunk with all sequences and heads together
+    for (int64_t chunk = 0; chunk < num_chunks; chunk++) {
+        GGML_LOG_INFO("\n=== Processing chunk %lld ===\n", (long long) chunk);
+
+        // Create lambdas for tensor access similar to the recurrent function
+        const auto q_chunk = [chunk, src0](int64_t seq, int64_t head, int64_t token_idx, int64_t i) {
+            return ggml_get_f32_nd(src0, i, chunk * chunk_size + token_idx, head, seq);
+        };
+        const auto k_chunk = [chunk, src1](int64_t seq, int64_t head, int64_t token_idx, int64_t i) {
+            return ggml_get_f32_nd(src1, i, chunk * chunk_size + token_idx, head, seq);
+        };
+        const auto g_chunk = [chunk, src3](int64_t seq, int64_t head, int64_t token_idx) {
+            return ggml_get_f32_nd(src3, chunk * chunk_size + token_idx, 0, head, seq);
+        };
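+        // Note: ggml_get_f32_nd indexes through the tensor's byte strides (nb[0..3]),
+        // so these accessors remain correct even if the sources are not contiguous.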
+
+        // Allocate per-chunk arrays containing all sequences and heads
+        float * temp_state    = (float *) malloc(S_v * S_v * H_v * n_seqs * sizeof(float));
+        float * core_attn_out = (float *) malloc(chunk_size * S_v * H_v * n_seqs * sizeof(float));
+        float * attn_inter    = (float *) malloc(chunk_size * S_v * H_v * n_seqs * sizeof(float));
+        float * v_new         = (float *) malloc(chunk_size * S_v * H_v * n_seqs * sizeof(float));
+        float * v_prime       = (float *) malloc(chunk_size * S_v * H_v * n_seqs * sizeof(float));
+        float * g_diff_exp    = (float *) malloc(chunk_size * H_v * n_seqs * sizeof(float));
+        float * g_last        = (float *) malloc(H_v * n_seqs * sizeof(float));
+
+        // Initialize temp_state with zeros for all sequences and heads (state should be empty initially)
+        memset(temp_state, 0, S_v * S_v * H_v * n_seqs * sizeof(float));
+
+        // Create temporary arrays for entire chunk
+        float * q_chunk_data = (float *) malloc(chunk_size * S_v * H_v * n_seqs * sizeof(float));
+        float * k_chunk_data = (float *) malloc(chunk_size * S_v * H_v * n_seqs * sizeof(float));
+        float * q_g_exp      = (float *) malloc(chunk_size * S_v * H_v * n_seqs * sizeof(float));
+        float * attn_v_new   = (float *) malloc(chunk_size * S_v * H_v * n_seqs * sizeof(float));
+
+        // Fill temporary arrays with data from all sequences and heads
+        for (int64_t seq = 0; seq < n_seqs; seq++) {
+            for (int64_t head = 0; head < H_v; head++) {
+                float * q_ptr = q_chunk_data + seq * (chunk_size * S_v * H_v) + head * (chunk_size * S_v);
+                float * k_ptr = k_chunk_data + seq * (chunk_size * S_v * H_v) + head * (chunk_size * S_v);
+                float * g_ptr = g + seq * (chunk_size * H_v) + head * chunk_size;
+                float * q_g_exp_ptr = q_g_exp + seq * (chunk_size * S_v * H_v) + head * (chunk_size * S_v);
+
+                // Fill q, k and g data
                 for (int64_t i = 0; i < chunk_size; i++) {
-                    for (int64_t j = 0; j < chunk_size; j++) {
-                        mask[i * chunk_size + j] = (j > i); // True for upper triangular (excluding diagonal)
+                    for (int64_t d = 0; d < S_v; d++) {
+                        q_ptr[i * S_v + d] = q_chunk(seq, head, i, d);
+                        k_ptr[i * S_v + d] = k_chunk(seq, head, i, d);
                     }
+                    g_ptr[i] = g_chunk(seq, head, i);
                 }
-
-                // Python loop implementation:
-                // q_i, k_i, v_i = query[:, :, i], key[:, :, i], value[:, :, i]
-                // attn = (q_i @ k_i.transpose(-1, -2) * decay_mask[:, :, i]).masked_fill_(mask, 0)
-                delta_compute_q_k_attn_f32(q_chunk, k_chunk, decay_mask_chunk, attn, mask, chunk_size, S_k);
-
-                // v_prime = (k_cumdecay[:, :, i]) @ last_recurrent_state
-                // Calculate the correct state pointer for this head and sequence
-                float * head_state_data = state_data + (seq * S_v * S_v * H_v) + (head * S_v * S_v);
-
-
-                delta_matmul_state_f32(k_cumdecay_chunk, head_state_data, v_prime, chunk_size, S_k, S_v);
-
-                // v_new = v_i - v_prime
-                delta_tensor_subtract_f32(value_chunk, v_prime, v_new, chunk_size * S_v);
-
-                // attn_inter = (q_i * g[:, :, i, :, None].exp()) @ last_recurrent_state
-                float * q_g_exp = (float *) malloc(chunk_size * S_k * sizeof(float));
+
+                // Compute q_g_exp = q * g.exp()
                 for (int64_t i = 0; i < chunk_size; i++) {
-                    for (int64_t d = 0; d < S_k; d++) {
-                        int64_t q_idx = i * S_k + d;
-                        q_g_exp[q_idx] = q_chunk[q_idx] * expf(g_chunk[i]);
+                    for (int64_t d = 0; d < S_v; d++) {
+                        q_g_exp_ptr[i * S_v + d] = q_ptr[i * S_v + d] * expf(g_ptr[i]);
                     }
                 }
-                delta_matmul_state_f32(q_g_exp, head_state_data, attn_inter, chunk_size, S_k, S_v);
-
-                // core_attn_out[:, :, i] = attn_inter + attn @ v_new
-                float * attn_v_new = (float *) malloc(chunk_size * S_v * sizeof(float));
-                delta_matmul_state_f32(attn, v_new, attn_v_new, chunk_size, chunk_size, S_v);
-                delta_tensor_add_f32(attn_inter, attn_v_new, core_attn_out_chunk, chunk_size * S_v);
-
-                // Store the result in the output tensor
+            }
+        }
+
+        print_debug_info(q_chunk_data, chunk_size * S_v * H_v * n_seqs, "q_i_chunk", chunk);
+        print_debug_info(k_chunk_data, chunk_size * S_v * H_v * n_seqs, "k_i_chunk", chunk);
+
+        // Step 4: Compute NEW attention matrix for this chunk: attn = (q_i @ k_i.transpose(-1, -2) * decay_mask[:, :, i]).masked_fill_(mask, 0)
+        // Note: decay_mask[:, :, i] means we need to use the decay_mask for this specific chunk
+        // The mask applied is the simple causal attention mask: torch.triu(torch.ones(chunk_size, chunk_size), diagonal=1)
+
+        // Now compute attention for all sequences and heads together
+        for (int64_t seq = 0; seq < n_seqs; seq++) {
+            for (int64_t head = 0; head < H_v; head++) {
+                float * attn_ptr = attn + seq * (chunk_size * chunk_size * H_v) + head * (chunk_size * chunk_size);
+                const float * q_ptr = q_chunk_data + seq * (chunk_size * S_v * H_v) + head * (chunk_size * S_v);
+                const float * k_ptr = k_chunk_data + seq * (chunk_size * S_v * H_v) + head * (chunk_size * S_v);
+
+                float * k_trans = (float *) malloc(chunk_size * S_v * sizeof(float));
+                for (int i = 0; i < S_v; i++) {
+                    for (int j = 0; j < chunk_size; j++) {
+                        k_trans[i * chunk_size + j] = k_ptr[j * S_v + i];
+                    }
+                }
+
+                delta_matmul_f32(q_ptr, k_trans, attn_ptr, chunk_size, chunk_size, S_v);
+                free(k_trans);
+            }
+        }
+        print_debug_info(attn, chunk_size * chunk_size * H_v * n_seqs, "q_k_trans", chunk);
+
+
+        for (int64_t seq = 0; seq < n_seqs; seq++) {
+            for (int64_t head = 0; head < H_v; head++) {
                 for (int64_t i = 0; i < chunk_size; i++) {
-                    for (int64_t d = 0; d < S_v; d++) {
-                        if ((chunk * chunk_size + i) >= n_tokens) continue;
-                        int64_t output_idx = seq * (n_tokens * S_v * H_v) + head * (n_tokens * S_v) + (chunk * chunk_size + i) * S_v + d;
-                        output[output_idx] = core_attn_out_chunk[i * S_v + d];
+                    for (int64_t j = 0; j < chunk_size; j++) {
+                        float * attn_ptr = attn + seq * (chunk_size * chunk_size * H_v) + head * (chunk_size * chunk_size);
+                        const float * decay_mask_ptr = (float *) src5->data + seq * (chunk_size * chunk_size * H_v) + head * (chunk_size * chunk_size);
+                        float attn_val = attn_ptr[i * chunk_size + j] * decay_mask_ptr[i * chunk_size + j];
+                        // Apply simple causal attention mask (upper triangular with diagonal=1)
+                        // This corresponds to: torch.triu(torch.ones(chunk_size, chunk_size), diagonal=1)
+                        if (j > i) {
+                            attn_val = 0.0f;
+                        }
+                        attn_ptr[i * chunk_size + j] = attn_val;
                     }
                 }
-
-                // g_last = g[:, :, i, -1, None, None].exp()
-                *g_last = expf(g_chunk[chunk_size - 1]);
-
-                // Prepare g_diff_exp = (g[:, :, i, -1, None] - g[:, :, i]).exp()
-                float g_last_val = g_chunk[chunk_size - 1];
+            }
+        }
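+
+        // At this point attn holds (q_i @ k_i^T) * decay_mask with the strict upper
+        // triangle zeroed, i.e. the intra-chunk attention weights of the reference.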
+
+        print_debug_info(attn, chunk_size * chunk_size * H_v * n_seqs, "attn_step4_new_chunk", chunk);
+
+        // v_prime = (k_cumdecay[:, :, i]) @ last_recurrent_state
+        // k_cumdecay has shape [chunk_size, v_head_dim], state has shape [v_head_dim, v_head_dim]
+        delta_matmul_state_chunk_f32(k_cumdecay, state_data, v_prime, chunk_size, S_v, S_v, n_seqs, H_v);
+        print_debug_info(v_prime, chunk_size * S_v * H_v * n_seqs, "v_prime_chunk", chunk);
+
+        // v_new = v_i - v_prime
+        delta_tensor_subtract_chunk_f32(value, v_prime, v_new, chunk_size * S_v, n_seqs, H_v);
+        print_debug_info(v_new, chunk_size * S_v * H_v * n_seqs, "v_new_chunk", chunk);
+
+        // attn_inter = (q_i * g[:, :, i, :, None].exp()) @ last_recurrent_state
+        delta_matmul_state_chunk_f32(q_g_exp, state_data, attn_inter, chunk_size, S_v, S_v, n_seqs, H_v);
+        print_debug_info(attn_inter, chunk_size * S_v * H_v * n_seqs, "attn_inter_chunk", chunk);
+
+        // core_attn_out[:, :, i] = attn_inter + attn @ v_new
+        // Use regular matrix multiplication for attn @ v_new
+        for (int64_t seq = 0; seq < n_seqs; seq++) {
+            for (int64_t head = 0; head < H_v; head++) {
+                const float * attn_ptr = attn + seq * (chunk_size * chunk_size * H_v) + head * (chunk_size * chunk_size);
+                const float * v_new_ptr = v_new + seq * (chunk_size * S_v * H_v) + head * (chunk_size * S_v);
+                float * attn_v_new_ptr = attn_v_new + seq * (chunk_size * S_v * H_v) + head * (chunk_size * S_v);
+
+                // Compute attn @ v_new: [chunk_size, chunk_size] @ [chunk_size, S_v] -> [chunk_size, S_v]
+                delta_matmul_f32(attn_ptr, v_new_ptr, attn_v_new_ptr, chunk_size, S_v, chunk_size);
+            }
+        }
+        print_debug_info(attn_v_new, chunk_size * S_v * H_v * n_seqs, "attn_v_new_chunk", chunk);
+        delta_tensor_add_chunk_f32(attn_inter, attn_v_new, core_attn_out, chunk_size * S_v, n_seqs, H_v);
+        print_debug_info(core_attn_out, chunk_size * S_v * H_v * n_seqs, "core_attn_out_chunk", chunk);
+
+        // Prepare g_last and g_diff_exp for state update
+        for (int64_t seq = 0; seq < n_seqs; seq++) {
+            for (int64_t head = 0; head < H_v; head++) {
+                float * g_ptr = g + seq * (chunk_size * H_v) + head * chunk_size;
+                float g_last_val = g_ptr[chunk_size - 1];
+                g_last[seq * H_v + head] = expf(g_last_val);
+
+                float * g_diff_exp_ptr = g_diff_exp + seq * (chunk_size * H_v) + head * chunk_size;
                 for (int64_t i = 0; i < chunk_size; i++) {
-                    g_diff_exp[i] = expf(g_last_val - g_chunk[i]);
+                    float diff = g_last_val - g_ptr[i];
+                    g_diff_exp_ptr[i] = expf(diff);
                 }
-
-                // last_recurrent_state = (
-                //     last_recurrent_state * g_last
-                //     + (k_i * (g[:, :, i, -1, None] - g[:, :, i]).exp()[..., None]).transpose(-1, -2) @ v_new
-                // )
-                float * new_recurrent_state = (float *) malloc(S_v * S_v * sizeof(float));
-
-
-                delta_update_recurrent_state_f32(head_state_data, g_last, k_chunk, g_diff_exp, v_new,
-                                                 new_recurrent_state, chunk_size, S_v, S_v);
-
-
-                // Store the new state
-                for (int64_t i = 0; i < S_v; i++) {
+            }
+        }
+
+        print_debug_info(g_last, H_v * n_seqs, "g_last_chunk", chunk);
+        print_debug_info(g_diff_exp, chunk_size * H_v * n_seqs, "g_diff_exp", chunk);
+
+        float * k_g_diffexp = (float *) malloc(chunk_size * S_v * H_v * n_seqs * sizeof(float));
+        for (int64_t seq = 0; seq < n_seqs; seq++) {
+            for (int64_t head = 0; head < H_v; head++) {
+                for (int64_t i = 0; i < chunk_size; i++) {
                     for (int64_t j = 0; j < S_v; j++) {
-                        int64_t state_idx = seq * S_v * S_v * H_v + head * S_v * S_v + i * S_v + j;
-                        new_state[state_idx] = new_recurrent_state[i * S_v + j];
+                        k_g_diffexp[seq * (chunk_size * S_v * H_v) + head * (chunk_size * S_v) + i * S_v + j] =
+                            k_chunk(seq, head, i, j) * g_diff_exp[seq * (chunk_size * H_v) + head * chunk_size + i];
                     }
                 }
-
-                // Update the original state tensor with the new state for the next chunk
+            }
+        }
+        print_debug_info(k_g_diffexp, chunk_size * S_v * H_v * n_seqs, "k_g_diffexp", chunk);
+        float * k_g_diffexp_T = (float *) malloc(chunk_size * S_v * H_v * n_seqs * sizeof(float));
+        for (int64_t seq = 0; seq < n_seqs; seq++) {
+            for (int64_t head = 0; head < H_v; head++) {
+                for (int64_t i = 0; i < S_v; i++) {
+                    for (int64_t j = 0; j < chunk_size; j++) {
+                        k_g_diffexp_T[seq * (chunk_size * S_v * H_v) + head * (chunk_size * S_v) + i * chunk_size + j] =
+                            k_g_diffexp[seq * (chunk_size * S_v * H_v) + head * (chunk_size * S_v) + j * S_v + i];
+                    }
+                }
+            }
+        }
+
+        // for (int64_t seq = 0; seq < n_seqs; seq++) {
+        //     for (int64_t head = 0; head < H_v; head++) {
+        //         GGML_LOG_INFO("Sequence %ld, head %ld: \n[ ", seq, head);
+        //         for (int i = 0; i < chunk_size; i++) {
+        //             GGML_LOG_INFO("[ ");
+        //             for (int j = 0; j < S_v; j++) {
+        //                 GGML_LOG_INFO("%.6f", k_g_diffexp[(chunk_size * S_v * H_v) * seq + (chunk_size * S_v) * head + i * S_v + j]);
+        //                 if (j < chunk_size - 1) {
+        //                     GGML_LOG_INFO(", ");
+        //                 }
+        //             }
+        //             GGML_LOG_INFO("], \n");
+        //         }
+        //         GGML_LOG_INFO("]\n");
+        //     }
+        // }
+
+        print_debug_info(k_g_diffexp_T, chunk_size * S_v * H_v * n_seqs, "k_g_diffexp_T", chunk);
+
+        float * kgd_mul_vnew = (float *) malloc(S_v * S_v * H_v * n_seqs * sizeof(float));
+
+        for (int64_t seq = 0; seq < n_seqs; seq++) {
+            for (int64_t head = 0; head < H_v; head++) {
+                delta_matmul_f32(k_g_diffexp_T + (chunk_size * S_v * H_v) * seq + (chunk_size * S_v) * head,
+                                 v_new + (chunk_size * S_v * H_v) * seq + (chunk_size * S_v) * head,
+                                 kgd_mul_vnew + (S_v * S_v * H_v) * seq + (S_v * S_v) * head,
+                                 S_v, S_v, chunk_size);
+            }
+        }
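+        // kgd_mul_vnew = (k_i * g_diff_exp[:, None])^T @ v_new, i.e. the cross term of the
+        // recurrent-state update.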
+        print_debug_info(kgd_mul_vnew, S_v * S_v * H_v * n_seqs, "kgd_mul_vnew", chunk);
+
+        // for (int64_t seq = 0; seq < n_seqs; seq++) {
+        //     for (int64_t head = 0; head < H_v; head++) {
+        //         GGML_LOG_INFO("Sequence %ld, head %ld: \n[ ", seq, head);
+        //         for (int i = 0; i < S_v; i++) {
+        //             GGML_LOG_INFO("[ ");
+        //             for (int j = 0; j < S_v; j++) {
+        //                 GGML_LOG_INFO("%.6f", kgd_mul_vnew[(S_v * S_v * H_v) * seq + (S_v * S_v) * head + i * S_v + j]);
+        //                 if (j < S_v - 1) {
+        //                     GGML_LOG_INFO(", ");
+        //                 }
+        //             }
+        //             GGML_LOG_INFO("], \n");
+        //         }
+        //         GGML_LOG_INFO("]\n");
+        //     }
+        // }
+
+        for (int64_t seq = 0; seq < n_seqs; seq++) {
+            for (int64_t head = 0; head < H_v; head++) {
+                for (int i = 0; i < S_v; i++) {
+                    for (int j = 0; j < S_v; j++) {
+                        temp_state[(S_v * S_v * H_v) * seq + (S_v * S_v) * head + S_v * i + j] =
+                            state_data[(S_v * S_v * H_v) * seq + (S_v * S_v) * head + S_v * i + j] * g_last[seq * H_v + head] +
+                            kgd_mul_vnew[(S_v * S_v * H_v) * seq + (S_v * S_v) * head + S_v * i + j];
+                    }
+                }
+            }
+        }
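+        // temp_state now holds the updated recurrent state for this chunk:
+        //     state <- state * g_last + kgd_mul_vnew   (with g_last = exp(g[chunk_size - 1]))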
+        print_debug_info(temp_state, S_v * S_v * H_v * n_seqs, "temp_state", chunk);
+
+        // Free temporary memory
+        free(q_chunk_data);
+        free(k_chunk_data);
+        free(q_g_exp);
+        free(attn_v_new);
+        free(kgd_mul_vnew);
+        free(k_g_diffexp_T);
+        free(k_g_diffexp);
+
+        // Store output for this chunk (all sequences and heads)
+        for (int64_t seq = 0; seq < n_seqs; seq++) {
+            for (int64_t head = 0; head < H_v; head++) {
+                float * core_attn_out_ptr = core_attn_out + seq * (chunk_size * S_v * H_v) + head * (chunk_size * S_v);
+
+                // Store output for this chunk, skipping any padded tail tokens
+                for (int64_t i = 0; i < chunk_size; i++) {
+                    if ((chunk * chunk_size + i) >= n_tokens) continue;
+                    for (int64_t d = 0; d < S_v; d++) {
+                        int64_t output_idx =
+                            seq * (n_tokens * S_v * H_v) + head * (n_tokens * S_v) + (chunk * chunk_size + i) * S_v + d;
+                        output[output_idx] = core_attn_out_ptr[i * S_v + d];
+                    }
+                }
+            }
+        }
+        print_debug_info(output, S_v * H_v * n_tokens * n_seqs, "output", chunk);
+
+        // Update state tensor (all sequences and heads)
+        for (int64_t seq = 0; seq < n_seqs; seq++) {
+            for (int64_t head = 0; head < H_v; head++) {
+                float * temp_state_ptr = temp_state + seq * (S_v * S_v * H_v) + head * (S_v * S_v);
+
                 for (int64_t i = 0; i < S_v; i++) {
                     for (int64_t j = 0; j < S_v; j++) {
-                        int64_t state_idx = i * S_v + j;
-                        head_state_data[state_idx] = new_recurrent_state[state_idx];
+                        int64_t state_idx = seq * S_v * S_v * H_v + head * S_v * S_v + i * S_v + j;
+                        new_state[state_idx] = temp_state_ptr[i * S_v + j];
+                        *(state_ptr(seq, head, i, j)) = temp_state_ptr[i * S_v + j];
                     }
                 }
-
-                // Recalculate head_state_data to point to the updated state for the next iteration
-                head_state_data = state_data + (seq * S_v * S_v * H_v) + (head * S_v * S_v);
-
-                // Free temporary memory
-                free(attn);
-                free(v_prime);
-                free(v_new);
-                free(attn_inter);
-                free(core_attn_out_chunk);
-                free(g_last);
-                free(g_diff_exp);
-                free(mask);
-                free(q_g_exp);
-                free(attn_v_new);
-                free(new_recurrent_state);
-
-                // Free the value and k_cumdecay allocated at the beginning of the loop
-                free(value_chunk);
-                free(k_cumdecay);
             }
         }
-    }
-}
+        print_debug_info(new_state, S_v * S_v * H_v * n_seqs, "new_state", chunk);
-static void print_debug_info(float * data, size_t size, const char * name, int64_t token) {
-    GGML_LOG_INFO("\nggml-debug: %s (%ld) first 5 values: [%.6f, %.6f, %.6f, %.6f, %.6f, ...]\n",
-                  name, token, data[0], data[1], data[2], data[3], data[4]);
-    double sum = 0.0;
-    for (unsigned int i = 0; i < size; i++) {
-        sum += data[i];
+        free(temp_state);
+        free(core_attn_out);
+        free(attn_inter);
+        free(v_new);
+        free(v_prime);
+        free(g_diff_exp);
+        free(g_last);
     }
-    GGML_LOG_INFO("sum = %.10f\n", sum);
+
+    GGML_ASSERT(output + S_v * H_v * n_tokens * n_seqs == new_state);
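+    // Layout check: the state section must start immediately after the output section of dst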
+    free(attn);
+    free(value);
+    free(k_cumdecay);
+    free(mask);
+    free(g);
 }
 
 void ggml_compute_forward_delta_net_recurrent_f32(const ggml_compute_params * params, ggml_tensor * dst) {
@@ -10971,7 +11316,7 @@ void ggml_compute_forward_delta_net_recurrent_f32(const ggml_compute_params * pa
             }
         }
     }
-    //print_debug_info(temp_state, n_seqs * H_v * S_v * S_v, "temp_state_copy", token);
+    print_debug_info(temp_state, n_seqs * H_v * S_v * S_v, "temp_state_copy", token);
 
     // 1. last_recurrent_state = last_recurrent_state * g_t (for all seqs and heads)
     for (int64_t seq = 0; seq < n_seqs; seq++) {
@@ -10985,7 +11330,7 @@ void ggml_compute_forward_delta_net_recurrent_f32(const ggml_compute_params * pa
             }
         }
     }
-    //print_debug_info(temp_state, n_seqs * H_v * S_v * S_v, "temp_state_times_g_t", token);
+    print_debug_info(temp_state, n_seqs * H_v * S_v * S_v, "temp_state_times_g_t", token);
 
     // 2. kv_mem = (last_recurrent_state * k_t.unsqueeze(-1)).sum(dim=-2) (for all seqs and heads)
     for (int64_t seq = 0; seq < n_seqs; seq++) {
@@ -11000,7 +11345,7 @@ void ggml_compute_forward_delta_net_recurrent_f32(const ggml_compute_params * pa
             }
         }
     }
-    //print_debug_info(kv_mem, n_seqs * H_v * S_v, "kv_mem", token);
+    print_debug_info(kv_mem, n_seqs * H_v * S_v, "kv_mem", token);
 
     // 3. delta = (v_t - kv_mem) * beta_t (for all seqs and heads)
     for (int64_t seq = 0; seq < n_seqs; seq++) {
@@ -11012,7 +11357,7 @@ void ggml_compute_forward_delta_net_recurrent_f32(const ggml_compute_params * pa
             }
         }
     }
-    //print_debug_info(delta, n_seqs * H_v * S_v, "delta", token);
+    print_debug_info(delta, n_seqs * H_v * S_v, "delta", token);
 
     // 4. last_recurrent_state = last_recurrent_state + k_t.unsqueeze(-1) * delta.unsqueeze(-2) (for all seqs and heads)
    for (int64_t seq = 0; seq < n_seqs; seq++) {
@@ -11026,7 +11371,7 @@ void ggml_compute_forward_delta_net_recurrent_f32(const ggml_compute_params * pa
             }
         }
     }
-    //print_debug_info(temp_state, n_seqs * H_v * S_v * S_v, "temp_state", token);
+    print_debug_info(temp_state, n_seqs * H_v * S_v * S_v, "temp_state", token);
 
     // 5. core_attn_out[:, :, i] = (last_recurrent_state * q_t.unsqueeze(-1)).sum(dim=-2) (for all seqs and heads)
     for (int64_t seq = 0; seq < n_seqs; seq++) {
@@ -11040,7 +11385,7 @@ void ggml_compute_forward_delta_net_recurrent_f32(const ggml_compute_params * pa
             }
         }
     }
-    //print_debug_info(attn_out_t, n_seqs * S_v * H_v, "attn_out_t", token);
+    print_debug_info(attn_out_t, n_seqs * S_v * H_v, "attn_out_t", token);
 
     // Store the output for this token (for all seqs and heads)
     for (int64_t seq = 0; seq < n_seqs; seq++) {