@@ -10530,62 +10530,73 @@ static void delta_compute_k_beta_key_t_f32(const float * k_beta, const float * k
 }
 
 // Helper function to apply triangular updates to entire chunk (all sequences and heads)
-static void delta_apply_triangular_updates_chunk_f32(float * attn, const int64_t chunk_size,
-                                                      const int64_t n_seqs, const int64_t H_v) {
+static void delta_apply_triangular_updates_chunk_f32(float * attn,
+                                                      const int64_t chunk_size,
+                                                      const int64_t n_seqs,
+                                                      const int64_t H_v,
+                                                      int num_chunks) {
     for (int64_t seq = 0; seq < n_seqs; seq++) {
-        for (int64_t head = 0; head < H_v; head++) {
-            float * attn_ptr = attn + seq * (chunk_size * chunk_size * H_v) + head * (chunk_size * chunk_size);
-
-            // Apply triangular updates following the Python reference exactly:
-            // for i in range(1, chunk_size):
-            //     row = attn[..., i, :i].clone()
-            //     sub = attn[..., :i, :i].clone()
-            //     attn[..., i, :i] = row + (row.unsqueeze(-1) * sub).sum(-2)
-            for (int64_t i = 1; i < chunk_size; i++) {
-                // Create temporary storage for row and sub to avoid modifying during computation
-                float * row = (float *) malloc(i * sizeof(float));
-                float * sub = (float *) malloc(i * i * sizeof(float));
-
-                // Copy row = attn[..., i, :i]
-                for (int64_t j = 0; j < i; j++) {
-                    row[j] = attn_ptr[i * chunk_size + j];
-                }
-
-                // Copy sub = attn[..., :i, :i]
-                for (int64_t k = 0; k < i; k++) {
+        for (int i = 0; i < num_chunks; i++) {
+            for (int64_t head = 0; head < H_v; head++) {
+                float * attn_ptr = attn + seq * (chunk_size * chunk_size * H_v) + (head * num_chunks + i) * (chunk_size * chunk_size);
+
+                // Apply triangular updates following the Python reference exactly:
+                // for i in range(1, chunk_size):
+                //     row = attn[..., i, :i].clone()
+                //     sub = attn[..., :i, :i].clone()
+                //     attn[..., i, :i] = row + (row.unsqueeze(-1) * sub).sum(-2)
+                for (int64_t i = 1; i < chunk_size; i++) {
+                    // Create temporary storage for row and sub to avoid modifying during computation
+                    float * row = (float *) malloc(i * sizeof(float));
+                    float * sub = (float *) malloc(i * i * sizeof(float));
+
+                    // Copy row = attn[..., i, :i]
                     for (int64_t j = 0; j < i; j++) {
-                        sub[k * i + j] = attn_ptr[k * chunk_size + j];
+                        row[j] = attn_ptr[i * chunk_size + j];
                     }
-                }
-
-                // Compute updates for each j in :i
-                for (int64_t j = 0; j < i; j++) {
-                    // Compute (row.unsqueeze(-1) * sub).sum(-2)
-                    float sum_val = 0.0f;
+
+                    // Copy sub = attn[..., :i, :i]
                     for (int64_t k = 0; k < i; k++) {
-                        sum_val += row[k] * sub[k * i + j];
+                        for (int64_t j = 0; j < i; j++) {
+                            sub[k * i + j] = attn_ptr[k * chunk_size + j];
+                        }
+                    }
+
+                    // Compute updates for each j in :i
+                    for (int64_t j = 0; j < i; j++) {
+                        // Compute (row.unsqueeze(-1) * sub).sum(-2)
+                        float sum_val = 0.0f;
+                        for (int64_t k = 0; k < i; k++) {
+                            sum_val += row[k] * sub[k * i + j];
+                        }
+
+                        // Update: attn[..., i, j] = row[j] + sum_val
+                        attn_ptr[i * chunk_size + j] = row[j] + sum_val;
                     }
-
-                    // Update: attn[..., i, j] = row[j] + sum_val
-                    attn_ptr[i * chunk_size + j] = row[j] + sum_val;
+
+                    free(row);
+                    free(sub);
                 }
-
-                free(row);
-                free(sub);
             }
         }
     }
 }
 
 // Helper function to add identity matrix to entire chunk (all sequences and heads)
-static void delta_add_identity_matrix_chunk_f32(float * matrix, const int64_t chunk_size,
-                                                 const int64_t n_seqs, const int64_t H_v) {
+static void delta_add_identity_matrix_chunk_f32(float * matrix,
+                                                 const int64_t chunk_size,
+                                                 const int64_t n_seqs,
+                                                 const int64_t H_v,
+                                                 int num_chunks) {
     for (int64_t seq = 0; seq < n_seqs; seq++) {
-        for (int64_t head = 0; head < H_v; head++) {
-            float * matrix_ptr = matrix + seq * (chunk_size * chunk_size * H_v) + head * (chunk_size * chunk_size);
-            // Add identity matrix directly
-            for (int64_t i = 0; i < chunk_size; i++) {
-                matrix_ptr[i * chunk_size + i] += 1.0f;
+        for (int i = 0; i < num_chunks; i++) {
+            for (int64_t head = 0; head < H_v; head++) {
+                float * matrix_ptr = matrix + seq * (chunk_size * chunk_size * H_v) +
+                                     (head * num_chunks + i) * (chunk_size * chunk_size);
+                // Add identity matrix directly
+                for (int64_t i = 0; i < chunk_size; i++) {
+                    matrix_ptr[i * chunk_size + i] += 1.0f;
+                }
             }
         }
     }
@@ -10617,15 +10628,19 @@ static void delta_compute_value_f32(const float * attn,
                                     const int64_t chunk_size,
                                     const int64_t v_head_dim,
                                     const int64_t n_heads,
-                                    const int64_t n_seqs) {
+                                    const int64_t n_seqs,
+                                    int num_chunks) {
     for (int64_t seq = 0; seq < n_seqs; seq++) {
-        for (int64_t head = 0; head < n_heads; head++) {
-            delta_matmul_f32(
-                attn + (chunk_size * chunk_size * n_heads) * seq + (chunk_size * chunk_size) * head,
-                v_beta + (chunk_size * v_head_dim * n_heads) * seq + (chunk_size * v_head_dim) * head,
-                value + (chunk_size * v_head_dim * n_heads) * seq + (chunk_size * v_head_dim) * head,
-                chunk_size, v_head_dim, chunk_size);
+        for (int i = 0; i < num_chunks; i++) {
+            for (int64_t head = 0; head < n_heads; head++) {
+                delta_matmul_f32(
+                    attn + (chunk_size * chunk_size * n_heads) * seq + (chunk_size * chunk_size) * (head * num_chunks + i),
+                    v_beta + (chunk_size * v_head_dim * n_heads) * seq + (chunk_size * v_head_dim) * head,
+                    value + (chunk_size * v_head_dim * n_heads) * seq + (chunk_size * v_head_dim) * head,
+                    chunk_size, v_head_dim, chunk_size);
+            }
         }
+
     }
 }
 
@@ -10913,11 +10928,11 @@ void ggml_compute_forward_delta_net_f32(const ggml_compute_params * params, ggml
     // int64_t total_params = n_seqs * H_v * num_chunks;
     // int64_t per_thread = (total_params % nth == 0) ? total_params / nth : (total_params / nth) + 1;
 
-    float * attn = (float *) malloc(chunk_size * chunk_size * H_v * n_seqs * sizeof(float));
-    float * value = (float *) malloc(chunk_size * S_v * H_v * n_seqs * sizeof(float));
-    float * k_cumdecay = (float *) malloc(chunk_size * S_v * H_v * n_seqs * sizeof(float));
+    float * attn = (float *) malloc(chunk_size * chunk_size * H_v * num_chunks * n_seqs * sizeof(float));
+    float * value = (float *) malloc(chunk_size * S_v * H_v * num_chunks * n_seqs * sizeof(float));
+    float * k_cumdecay = (float *) malloc(chunk_size * S_v * H_v * num_chunks * n_seqs * sizeof(float));
     bool * mask = (bool *) malloc(chunk_size * chunk_size * sizeof(bool));
-    float * g = (float *) malloc(chunk_size * H_v * n_seqs * sizeof(float));
+    float * g = (float *) malloc(chunk_size * H_v * num_chunks * n_seqs * sizeof(float));
 
     // Create upper triangular mask for causal attention (exclude diagonal)
     for (int64_t i = 0; i < chunk_size; i++) {
@@ -10934,18 +10949,20 @@ void ggml_compute_forward_delta_net_f32(const ggml_compute_params * params, ggml
     // This corresponds to the reference implementation:
     // for i in range(1, chunk_size): attn[..., i, :i] = row + (row.unsqueeze(-1) * sub).sum(-2)
     // attn = attn + torch.eye(chunk_size)
-    delta_apply_triangular_updates_chunk_f32(attn, chunk_size, n_seqs, H_v);
-    delta_add_identity_matrix_chunk_f32(attn, chunk_size, n_seqs, H_v);
+    delta_apply_triangular_updates_chunk_f32(attn, chunk_size, n_seqs, H_v, num_chunks);
+    delta_add_identity_matrix_chunk_f32(attn, chunk_size, n_seqs, H_v, num_chunks);
 
     // Compute value = attn @ v_beta
-    delta_compute_value_f32(attn, (const float *) src6->data, value, chunk_size, S_v, H_v, n_seqs);
+    delta_compute_value_f32(attn, (const float *) src6->data, value, chunk_size, S_v, H_v, n_seqs, num_chunks);
     for (int64_t seq = 0; seq < n_seqs; seq++) {
-        for (int64_t head = 0; head < H_v; head++) {
+        for (int i = 0; i < num_chunks; i++) {
+            for (int64_t head = 0; head < H_v; head++) {
             delta_compute_k_cumdecay_f32(attn + (chunk_size * chunk_size * H_v) * seq + (chunk_size * chunk_size) * head,
                                          (float *) src7->data + (chunk_size * S_v * H_v) * seq + (chunk_size * S_v) * head,
                                          g + (chunk_size * H_v) * seq + chunk_size * head,
                                          k_cumdecay + (chunk_size * S_v * H_v) * seq + (chunk_size * S_v) * head,
                                          chunk_size, S_v);
+            }
         }
     }
     print_debug_info(k_cumdecay, chunk_size * S_v * H_v * n_seqs, "k_cumdecay", -1);
@@ -10996,7 +11013,7 @@ void ggml_compute_forward_delta_net_f32(const ggml_compute_params * params, ggml
 
             // Compute q_g_exp = q * g.exp()
             for (int64_t i = 0; i < chunk_size; i++) {
-                for (int64_t d = 0; d < S_v; d++) {
+                for (int64_t d = 0; d < S_v; d++) {
                     q_g_exp_ptr[i * S_v + d] = q_ptr[i * S_v + d] * expf(g_ptr[i]);
                 }
             }
@@ -11196,8 +11213,11 @@ void ggml_compute_forward_delta_net_f32(const ggml_compute_params * params, ggml
         for (int64_t head = 0; head < H_v; head++) {
             float * core_attn_out_ptr = core_attn_out + seq * (chunk_size * S_v * H_v) + head * (chunk_size * S_v);
 
+            // Compute number of tokens for this chunk (chunk_size unless this is the last chunk)
+            int64_t n_tokens_chunk = chunk == num_chunks - 1 ? n_tokens % chunk_size : chunk_size;
+
             // Store output for this chunk
-            for (int64_t i = 0; i < n_tokens; i++) {
+            for (int64_t i = 0; i < n_tokens_chunk; i++) {
                 for (int64_t d = 0; d < S_v; d++) {
                     int64_t output_idx =
                         seq * (n_tokens * S_v * H_v) + head * (n_tokens * S_v) + (chunk * chunk_size + i) * S_v + d;