@@ -10769,10 +10769,11 @@ static void delta_compute_q_k_attn_chunk_f32(const float * q, const float * k, c
// Helper function for matrix multiplication with state tensors for entire chunk
static void delta_matmul_state_chunk_f32(const float * a, const float * state, float * dst,
const int64_t rows_a, const int64_t cols_a, const int64_t cols_state,
- const int64_t n_seqs, const int64_t H_v) {
+ const int64_t n_seqs, const int64_t H_v, int chunk, int num_chunks) {
for (int64_t seq = 0; seq < n_seqs; seq++) {
for (int64_t head = 0; head < H_v; head++) {
- const float * a_ptr = a + seq * (rows_a * cols_a * H_v) + head * (rows_a * cols_a);
+ const float * a_ptr = chunk < 0 ? a + seq * (rows_a * cols_a * H_v) + head * (rows_a * cols_a) :
+ a + seq * (rows_a * cols_a * H_v * num_chunks) + (head * num_chunks + chunk) * (rows_a * cols_a);
const float * state_ptr = state + seq * (cols_a * cols_state * H_v) + head * (cols_a * cols_state);
float * dst_ptr = dst + seq * (rows_a * cols_state * H_v) + head * (rows_a * cols_state);
delta_matmul_state_f32(a_ptr, state_ptr, dst_ptr, rows_a, cols_a, cols_state);
@@ -10817,10 +10818,10 @@ static void delta_update_recurrent_state_chunk_f32(const float * state, const fl

// Helper function for element-wise tensor subtraction for entire chunk
static void delta_tensor_subtract_chunk_f32(const float * a, const float * b, float * dst, const int64_t size,
- const int64_t n_seqs, const int64_t H_v) {
+ const int64_t n_seqs, const int64_t H_v, int num_chunks, int chunk) {
for (int64_t seq = 0; seq < n_seqs; seq++) {
for (int64_t head = 0; head < H_v; head++) {
- const float * a_ptr = a + seq * (size * H_v) + head * size;
+ const float * a_ptr = a + seq * (size * num_chunks * H_v) + (head * num_chunks + chunk) * size;
const float * b_ptr = b + seq * (size * H_v) + head * size;
float * dst_ptr = dst + seq * (size * H_v) + head * size;
delta_tensor_subtract_f32(a_ptr, b_ptr, dst_ptr, size);
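Both helpers above now take the chunk index so they can address a slice of a buffer that holds every chunk of the sequence, not just a single per-chunk scratch buffer. A minimal sketch of the two addressing forms, for reference only (delta_chunk_slice is not a function in this patch; block stands for rows_a * cols_a, or size in the subtract helper):

// Illustrative sketch of the addressing convention used by the updated helpers.
// chunk < 0 : compact per-chunk buffer, laid out as [n_seqs][H_v][block]
// chunk >= 0: one chunk's slice of an all-chunks buffer, [n_seqs][H_v][num_chunks][block]
static inline const float * delta_chunk_slice(const float * a, int64_t seq, int64_t head,
                                              int64_t block, int64_t H_v,
                                              int chunk, int num_chunks) {
    if (chunk < 0) {
        return a + seq * (block * H_v) + head * block;
    }
    return a + seq * (block * H_v * num_chunks) + (head * num_chunks + chunk) * block;
}

delta_matmul_state_chunk_f32 switches on chunk < 0; delta_tensor_subtract_chunk_f32 always uses the chunked form for its first operand.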
@@ -10889,8 +10890,8 @@ void ggml_compute_forward_delta_net_f32(const ggml_compute_params * params, ggml

float * dst_data = (float *) dst->data;
// Following GLA pattern: output is first part, state is second part
- float * output = dst_data; // [S_v * H_v, n_tokens, 1, n_seqs] - only real sequence length, not padded
- float * new_state = dst_data + (S_v * H_v * n_tokens * n_seqs); // [S_v, S_v * H_v, 1, n_seqs]
+ float * output = dst_data; // [S_v, H_v, n_tokens, n_seqs] - only real sequence length, not padded
+ float * new_state = dst_data + (S_v * H_v * n_tokens * n_seqs); // [S_v, S_v, H_v, n_seqs]
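These corrected shape comments match how new_state is addressed later in the patch. For orientation, a sketch of the per-(seq, head) indexing that the [S_v, S_v, H_v, n_seqs] layout implies (illustrative only; delta_state_at is not part of the change):

// Illustrative: address of state entry (i, j) for a given (seq, head), consistent with
// new_state[(S_v * S_v * H_v) * seq + (S_v * S_v) * head + S_v * i + j] in the state-update loop below.
static inline float * delta_state_at(float * new_state, int64_t seq, int64_t head,
                                     int64_t i, int64_t j, int64_t S_v, int64_t H_v) {
    return new_state + seq * (S_v * S_v * H_v) + head * (S_v * S_v) + i * S_v + j;
}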

const int ith = params->ith;
// const int nth = params->nth; // nth is unused
@@ -10968,7 +10969,7 @@ void ggml_compute_forward_delta_net_f32(const ggml_compute_params * params, ggml
}
}
}
- print_debug_info(k_cumdecay, chunk_size * S_v * H_v * n_seqs, "k_cumdecay", -1);
+ print_debug_info(k_cumdecay, chunk_size * S_v * H_v * num_chunks * n_seqs, "k_cumdecay", -1);
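The debug size changes because k_cumdecay is now sized for all chunks at once rather than a single chunk. Assuming the layout implied by the delta_matmul_state_chunk_f32 call below (which passes chunk and num_chunks), the buffer is read as [n_seqs][H_v][num_chunks][chunk_size][S_v]; the slice consumed for one (seq, head, chunk) would be:

// Illustrative only, not part of the patch:
const float * k_cumdecay_chunk = k_cumdecay
    + seq * (chunk_size * S_v * H_v * num_chunks)
    + (head * num_chunks + chunk) * (chunk_size * S_v);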

// Process each chunk with all sequences and heads together
for (int64_t chunk = 0; chunk < num_chunks; chunk++) {
@@ -10984,27 +10985,27 @@ void ggml_compute_forward_delta_net_f32(const ggml_compute_params * params, ggml
};

// Allocate per-chunk arrays containing all sequences and heads
- float * core_attn_out = (float *) malloc(chunk_size * S_v * H_v * n_seqs * sizeof(float));
- float * attn_inter = (float *) malloc(chunk_size * S_v * H_v * n_seqs * sizeof(float));
- float * v_new = (float *) malloc(chunk_size * S_v * H_v * n_seqs * sizeof(float));
- float * v_prime = (float *) malloc(chunk_size * S_v * H_v * n_seqs * sizeof(float));
- float * g_diff_exp = (float *) malloc(chunk_size * H_v * n_seqs * sizeof(float));
- float * g_last = (float *) malloc(H_v * n_seqs * sizeof(float));
+ float * pc_core_attn_out = (float *) malloc(chunk_size * S_v * H_v * n_seqs * sizeof(float));
+ float * pc_attn_inter = (float *) malloc(chunk_size * S_v * H_v * n_seqs * sizeof(float));
+ float * pc_v_new = (float *) malloc(chunk_size * S_v * H_v * n_seqs * sizeof(float));
+ float * pc_v_prime = (float *) malloc(chunk_size * S_v * H_v * n_seqs * sizeof(float));
+ float * pc_g_diff_exp = (float *) malloc(chunk_size * H_v * n_seqs * sizeof(float));
+ float * pc_g_last = (float *) malloc(H_v * n_seqs * sizeof(float));

// Create temporary arrays for entire chunk
- float * q_chunk_data = (float *) malloc(chunk_size * S_v * H_v * n_seqs * sizeof(float));
- float * k_chunk_data = (float *) malloc(chunk_size * S_v * H_v * n_seqs * sizeof(float));
- float * q_g_exp = (float *) malloc(chunk_size * S_v * H_v * n_seqs * sizeof(float));
- float * attn_v_new = (float *) malloc(chunk_size * S_v * H_v * n_seqs * sizeof(float));
+ float * pc_q_chunk_data = (float *) malloc(chunk_size * S_v * H_v * n_seqs * sizeof(float));
+ float * pc_k_chunk_data = (float *) malloc(chunk_size * S_v * H_v * n_seqs * sizeof(float));
+ float * pc_q_g_exp = (float *) malloc(chunk_size * S_v * H_v * n_seqs * sizeof(float));
+ float * pc_attn_v_new = (float *) malloc(chunk_size * S_v * H_v * n_seqs * sizeof(float));

// Fill temporary arrays with data from all sequences and heads
for (int64_t seq = 0; seq < n_seqs; seq++) {
for (int64_t head = 0; head < H_v; head++) {
- float * q_ptr = q_chunk_data + seq * (chunk_size * S_v * H_v) + head * (chunk_size * S_v);
- float * k_ptr = k_chunk_data + seq * (chunk_size * S_v * H_v) + head * (chunk_size * S_v);
+ float * q_ptr = pc_q_chunk_data + seq * (chunk_size * S_v * H_v) + head * (chunk_size * S_v);
+ float * k_ptr = pc_k_chunk_data + seq * (chunk_size * S_v * H_v) + head * (chunk_size * S_v);
float * g_ptr = g + (chunk_size * H_v * num_chunks) * seq + chunk_size * (head * num_chunks + chunk);

- float * q_g_exp_ptr = q_g_exp + seq * (chunk_size * S_v * H_v) + head * (chunk_size * S_v);
+ float * q_g_exp_ptr = pc_q_g_exp + seq * (chunk_size * S_v * H_v) + head * (chunk_size * S_v);

// Fill q, k, decay_mask, and g data
for (int64_t i = 0; i < chunk_size; i++) {
@@ -11024,8 +11025,8 @@ void ggml_compute_forward_delta_net_f32(const ggml_compute_params * params, ggml
}
}

- print_debug_info(q_chunk_data, chunk_size * S_v * H_v * n_seqs, "q_i_chunk", chunk);
- print_debug_info(k_chunk_data, chunk_size * S_v * H_v * n_seqs, "k_i_chunk", chunk);
+ print_debug_info(pc_q_chunk_data, chunk_size * S_v * H_v * n_seqs, "q_i_chunk", chunk);
+ print_debug_info(pc_k_chunk_data, chunk_size * S_v * H_v * n_seqs, "k_i_chunk", chunk);

// Step 4: Compute NEW attention matrix for this chunk: attn = (q_i @ k_i.transpose(-1, -2) * decay_mask[:, :, i]).masked_fill_(mask, 0)
// Note: decay_mask[:, :, i] means we need to use the decay_mask for this specific chunk
@@ -11035,8 +11036,8 @@ void ggml_compute_forward_delta_net_f32(const ggml_compute_params * params, ggml
for (int64_t seq = 0; seq < n_seqs; seq++) {
for (int64_t head = 0; head < H_v; head++) {
float * attn_ptr = attn + seq * (chunk_size * chunk_size * num_chunks * H_v) + (head * num_chunks + chunk) * (chunk_size * chunk_size);
- const float * q_ptr = q_chunk_data + seq * (chunk_size * S_v * H_v) + head * (chunk_size * S_v);
- const float * k_ptr = k_chunk_data + seq * (chunk_size * S_v * H_v) + head * (chunk_size * S_v);
+ const float * q_ptr = pc_q_chunk_data + seq * (chunk_size * S_v * H_v) + head * (chunk_size * S_v);
+ const float * k_ptr = pc_k_chunk_data + seq * (chunk_size * S_v * H_v) + head * (chunk_size * S_v);

float * k_trans = (float *) malloc(chunk_size * S_v * sizeof(float));
for (int i = 0; i < S_v; i++) {
@@ -11072,41 +11073,41 @@ void ggml_compute_forward_delta_net_f32(const ggml_compute_params * params, ggml

// v_prime = (k_cumdecay[:, :, i]) @ last_recurrent_state
// k_cumdecay has shape [chunk_size, v_head_dim], state has shape [v_head_dim, v_head_dim]
- delta_matmul_state_chunk_f32(k_cumdecay, new_state, v_prime, chunk_size, S_v, S_v, n_seqs, H_v);
- print_debug_info(v_prime, chunk_size * S_v * H_v * n_seqs, "v_prime_chunk", chunk);
+ delta_matmul_state_chunk_f32(k_cumdecay, new_state, pc_v_prime, chunk_size, S_v, S_v, n_seqs, H_v, chunk, num_chunks);
+ print_debug_info(pc_v_prime, chunk_size * S_v * H_v * n_seqs, "v_prime_chunk", chunk);

// v_new = v_i - v_prime
- delta_tensor_subtract_chunk_f32(value, v_prime, v_new, chunk_size * S_v, n_seqs, H_v);
- print_debug_info(v_new, chunk_size * S_v * H_v * n_seqs, "v_new_chunk", chunk);
+ delta_tensor_subtract_chunk_f32(value, pc_v_prime, pc_v_new, chunk_size * S_v, n_seqs, H_v, num_chunks, chunk);
+ print_debug_info(pc_v_new, chunk_size * S_v * H_v * n_seqs, "v_new_chunk", chunk);

// attn_inter = (q_i * g[:, :, i, :, None].exp()) @ last_recurrent_state
- delta_matmul_state_chunk_f32(q_g_exp, new_state, attn_inter, chunk_size, S_v, S_v, n_seqs, H_v);
- print_debug_info(attn_inter, chunk_size * S_v * H_v * n_seqs, "attn_inter_chunk", chunk);
+ delta_matmul_state_chunk_f32(pc_q_g_exp, new_state, pc_attn_inter, chunk_size, S_v, S_v, n_seqs, H_v, -1, -1);
+ print_debug_info(pc_attn_inter, chunk_size * S_v * H_v * n_seqs, "attn_inter_chunk", chunk);
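Note the -1, -1 arguments: pc_q_g_exp is packed per chunk just above, so it has no num_chunks stride and this call takes the chunk < 0 path of delta_matmul_state_chunk_f32, i.e. (for reference only):

// chunk < 0  =>  a_ptr = pc_q_g_exp + seq * (chunk_size * S_v * H_v) + head * (chunk_size * S_v);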

// core_attn_out[:, :, i] = attn_inter + attn @ v_new
// Use regular matrix multiplication for attn @ v_new
for (int64_t seq = 0; seq < n_seqs; seq++) {
for (int64_t head = 0; head < H_v; head++) {
const float * attn_ptr = attn + seq * (chunk_size * chunk_size * num_chunks * H_v) + (head * num_chunks + chunk) * (chunk_size * chunk_size);
- const float * v_new_ptr = v_new + seq * (chunk_size * S_v * H_v) + head * (chunk_size * S_v);
- float * attn_v_new_ptr = attn_v_new + seq * (chunk_size * S_v * H_v) + head * (chunk_size * S_v);
+ const float * v_new_ptr = pc_v_new + seq * (chunk_size * S_v * H_v) + head * (chunk_size * S_v);
+ float * attn_v_new_ptr = pc_attn_v_new + seq * (chunk_size * S_v * H_v) + head * (chunk_size * S_v);

// Compute attn @ v_new: [chunk_size, chunk_size] @ [chunk_size, S_v] -> [chunk_size, S_v]
delta_matmul_f32(attn_ptr, v_new_ptr, attn_v_new_ptr, chunk_size, S_v, chunk_size);
}
}
- print_debug_info(attn_v_new, chunk_size * S_v * H_v * n_seqs, "attn_v_new_chunk", chunk);
- delta_tensor_add_chunk_f32(attn_inter, attn_v_new, core_attn_out, chunk_size * S_v, n_seqs, H_v);
- print_debug_info(core_attn_out, chunk_size * S_v * H_v * n_seqs, "core_attn_out_chunk", chunk);
+ print_debug_info(pc_attn_v_new, chunk_size * S_v * H_v * n_seqs, "attn_v_new_chunk", chunk);
+ delta_tensor_add_chunk_f32(pc_attn_inter, pc_attn_v_new, pc_core_attn_out, chunk_size * S_v, n_seqs, H_v);
+ print_debug_info(pc_core_attn_out, chunk_size * S_v * H_v * n_seqs, "core_attn_out_chunk", chunk);

// Prepare g_last and g_diff_exp for state update
for (int64_t seq = 0; seq < n_seqs; seq++) {
for (int64_t head = 0; head < H_v; head++) {
- float * g_ptr = g + seq * (chunk_size * H_v) + head * chunk_size;
+ float * g_ptr = g + seq * (chunk_size * num_chunks * H_v) + (head * num_chunks + chunk) * chunk_size;
float g_last_val = g_ptr[chunk_size - 1];
- g_last[seq * H_v + head] = expf(g_last_val);
+ pc_g_last[seq * H_v + head] = expf(g_last_val);

- float * g_diff_exp_ptr = g_diff_exp + seq * (chunk_size * H_v) + head * chunk_size;
+ float * g_diff_exp_ptr = pc_g_diff_exp + seq * (chunk_size * H_v) + head * chunk_size;
for (int64_t i = 0; i < chunk_size; i++) {
float diff = g_last_val - g_ptr[i];
g_diff_exp_ptr[i] = expf(diff);
@@ -11114,8 +11115,8 @@ void ggml_compute_forward_delta_net_f32(const ggml_compute_params * params, ggml
}
}

- print_debug_info(g_last, H_v * n_seqs, "g_last_chunk", chunk);
- print_debug_info(g_diff_exp, chunk_size * H_v * n_seqs, "g_diff_exp", chunk);
+ print_debug_info(pc_g_last, H_v * n_seqs, "g_last_chunk", chunk);
+ print_debug_info(pc_g_diff_exp, chunk_size * H_v * n_seqs, "g_diff_exp", chunk);

float * k_g_diffexp = (float *) malloc(chunk_size * S_v * H_v * n_seqs * sizeof(float));
for (int64_t seq = 0; seq < n_seqs; seq++) {
@@ -11123,7 +11124,7 @@ void ggml_compute_forward_delta_net_f32(const ggml_compute_params * params, ggml
for (int64_t i = 0; i < chunk_size; i++) {
for (int64_t j = 0; j < S_v; j++) {
k_g_diffexp[seq * (chunk_size * S_v * H_v) + head * (chunk_size * S_v) + i * S_v + j] =
- k_chunk(seq, head, i, j) * g_diff_exp[seq * (chunk_size * H_v) + head * chunk_size + i];
+ k_chunk(seq, head, i, j) * pc_g_diff_exp[seq * (chunk_size * H_v) + head * chunk_size + i];
}
}
}
@@ -11165,7 +11166,7 @@ void ggml_compute_forward_delta_net_f32(const ggml_compute_params * params, ggml
for (int64_t seq = 0; seq < n_seqs; seq++) {
for (int64_t head = 0; head < H_v; head++) {
delta_matmul_f32(k_g_diffexp_T + (chunk_size * S_v * H_v) * seq + (chunk_size * S_v) * head,
- v_new + (chunk_size * S_v * H_v) * seq + (chunk_size * S_v) * head,
+ pc_v_new + (chunk_size * S_v * H_v) * seq + (chunk_size * S_v) * head,
kgd_mul_vnew + (S_v * S_v * H_v) * seq + (S_v * S_v) * head,
S_v, S_v, chunk_size);
}
@@ -11194,7 +11195,7 @@ void ggml_compute_forward_delta_net_f32(const ggml_compute_params * params, ggml
for (int i = 0; i < S_v; i++) {
for (int j = 0; j < S_v; j++) {
new_state[(S_v * S_v * H_v) * seq + (S_v * S_v) * head + S_v * i + j] =
- state_data[(S_v * S_v * H_v) * seq + (S_v * S_v) * head + S_v * i + j] * g_last[seq * H_v + head] +
+ state_data[(S_v * S_v * H_v) * seq + (S_v * S_v) * head + S_v * i + j] * pc_g_last[seq * H_v + head] +
kgd_mul_vnew[(S_v * S_v * H_v) * seq + (S_v * S_v) * head + S_v * i + j];
}
}
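Written out, the update this loop applies for each (seq, head), restating the code above in formula form:

// new_state = exp(g_last) * state + (k_i * exp(g_last - g_i))^T @ v_new
// where pc_g_last[seq * H_v + head] holds exp(g_last) for the chunk's final position and
// kgd_mul_vnew holds (k * g_diff_exp)^T @ v_new from the preceding matmul.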
@@ -11203,10 +11204,10 @@ void ggml_compute_forward_delta_net_f32(const ggml_compute_params * params, ggml
print_debug_info(new_state, S_v * S_v * H_v * n_seqs, "state_end_chunk", chunk);

// Free temporary memory
- free(q_chunk_data);
- free(k_chunk_data);
- free(q_g_exp);
- free(attn_v_new);
+ free(pc_q_chunk_data);
+ free(pc_k_chunk_data);
+ free(pc_q_g_exp);
+ free(pc_attn_v_new);
free(kgd_mul_vnew);
free(k_g_diffexp_T);
free(k_g_diffexp);
@@ -11214,7 +11215,7 @@ void ggml_compute_forward_delta_net_f32(const ggml_compute_params * params, ggml
// Store output for this chunk (all sequences and heads)
for (int64_t seq = 0; seq < n_seqs; seq++) {
for (int64_t head = 0; head < H_v; head++) {
- float * core_attn_out_ptr = core_attn_out + seq * (chunk_size * S_v * H_v) + head * (chunk_size * S_v);
+ float * core_attn_out_ptr = pc_core_attn_out + seq * (chunk_size * S_v * H_v) + head * (chunk_size * S_v);

// Compute number of tokens for this chunk (chunk_size unless this is the last chunk)
int64_t n_tokens_chunk = chunk == num_chunks - 1 ? n_tokens % chunk_size : chunk_size;
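One caveat, outside the scope of this patch: if n_tokens can be an exact multiple of chunk_size, the modulo above evaluates to 0 for the last chunk and no tokens would be copied for it. If that case is reachable, a guarded form would be:

// Hypothetical guard, only needed if n_tokens % chunk_size == 0 can occur here.
const int64_t rem = n_tokens % chunk_size;
int64_t n_tokens_chunk = (chunk == num_chunks - 1 && rem != 0) ? rem : chunk_size;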
@@ -11244,12 +11245,12 @@ void ggml_compute_forward_delta_net_f32(const ggml_compute_params * params, ggml
// }
print_debug_info(new_state, S_v * S_v * H_v * n_seqs, "new_state", chunk);

- free(core_attn_out);
- free(attn_inter);
- free(v_new);
- free(v_prime);
- free(g_diff_exp);
- free(g_last);
+ free(pc_core_attn_out);
+ free(pc_attn_inter);
+ free(pc_v_new);
+ free(pc_v_prime);
+ free(pc_g_diff_exp);
+ free(pc_g_last);
}

GGML_ASSERT(output + S_v * H_v * n_tokens * n_seqs == new_state);