|
|
@@ -10865,7 +10865,7 @@ void ggml_compute_forward_delta_net_f32(const ggml_compute_params * params, ggml
|
|
|
float * dst_data = (float *) dst->data;
|
|
|
// Following GLA pattern: output is first part, state is second part
|
|
|
float * output = dst_data; // [S_v * H_v, n_tokens, 1, n_seqs] - only real sequence length, not padded
|
|
|
- float * new_state = dst_data + (S_v * H_v * n_tokens * n_seqs); // [S_v * H_v, S_v * n_seqs, 1, 1]
|
|
|
+ float * new_state = dst_data + (S_v * H_v * n_tokens * n_seqs); // [S_v, S_v * H_v, 1, n_seqs]
|
|
|
|
|
|
const int ith = params->ith;
|
|
|
// const int nth = params->nth; // nth is unused
|
|
|
@@ -10884,6 +10884,20 @@ void ggml_compute_forward_delta_net_f32(const ggml_compute_params * params, ggml
|
|
|
|
|
|
float * state_data = (float *) src4->data;
|
|
|
|
|
|
+ // Init new state with initial state (will probably be zeroes)
|
|
|
+ for (int64_t seq = 0; seq < n_seqs; seq++) {
|
|
|
+ for (int64_t head = 0; head < H_v; head++) {
|
|
|
+ for (int64_t i = 0; i < S_v; i++) {
|
|
|
+ for (int64_t j = 0; j < S_v; j++) {
|
|
|
+ new_state[seq * (S_v * S_v * H_v) + head * (S_v * S_v) + i * S_v + j] =
|
|
|
+ state_data[seq * src4->nb[3] / sizeof(float) + (head * S_v + i) * src4->nb[1] / sizeof(float) + j * src4->nb[0] / sizeof(float)];
|
|
|
+ }
|
|
|
+ }
|
|
|
+ }
|
|
|
+ }
|
|
|
+ print_debug_info(new_state, S_v * S_v * H_v * n_seqs, "init_state", -1);
|
|
|
+
|
|
|
+
|
|
|
GGML_ASSERT(ggml_is_contiguous(src0));
|
|
|
GGML_ASSERT(ggml_is_contiguous(src1));
|
|
|
GGML_ASSERT(ggml_is_contiguous(src2));
|
|
|
@@ -10896,12 +10910,6 @@ void ggml_compute_forward_delta_net_f32(const ggml_compute_params * params, ggml
|
|
|
|
|
|
// int64_t total_params = n_seqs * H_v * num_chunks;
|
|
|
// int64_t per_thread = (total_params % nth == 0) ? total_params / nth : (total_params / nth) + 1;
|
|
|
-
|
|
|
- // Create helper lambda for state tensor access
|
|
|
- const auto state_ptr = [state_data, src4] (int64_t seq, int64_t head, int64_t i, int64_t j) {
|
|
|
- return state_data + (j * src4->nb[0] / sizeof(float)) + (i * src4->nb[1] / sizeof(float)) +
|
|
|
- (head * src4->nb[2] / sizeof(float)) + (seq * src4->nb[3] / sizeof(float));
|
|
|
- };
|
|
|
|
|
|
float * attn = (float *) malloc(chunk_size * chunk_size * H_v * n_seqs * sizeof(float));
|
|
|
float * value = (float *) malloc(chunk_size * S_v * H_v * n_seqs * sizeof(float));
|
|
|
@@ -11048,7 +11056,7 @@ void ggml_compute_forward_delta_net_f32(const ggml_compute_params * params, ggml
|
|
|
|
|
|
// v_prime = (k_cumdecay[:, :, i]) @ last_recurrent_state
|
|
|
// k_cumdecay has shape [chunk_size, v_head_dim], state has shape [v_head_dim, v_head_dim]
|
|
|
- delta_matmul_state_chunk_f32(k_cumdecay, state_data, v_prime, chunk_size, S_v, S_v, n_seqs, H_v);
|
|
|
+ delta_matmul_state_chunk_f32(k_cumdecay, new_state, v_prime, chunk_size, S_v, S_v, n_seqs, H_v);
|
|
|
print_debug_info(v_prime, chunk_size * S_v * H_v * n_seqs, "v_prime_chunk", chunk);
|
|
|
|
|
|
// v_new = v_i - v_prime
|
|
|
@@ -11056,7 +11064,7 @@ void ggml_compute_forward_delta_net_f32(const ggml_compute_params * params, ggml
|
|
|
print_debug_info(v_new, chunk_size * S_v * H_v * n_seqs, "v_new_chunk", chunk);
|
|
|
|
|
|
// attn_inter = (q_i * g[:, :, i, :, None].exp()) @ last_recurrent_state
|
|
|
- delta_matmul_state_chunk_f32(q_g_exp, state_data, attn_inter, chunk_size, S_v, S_v, n_seqs, H_v);
|
|
|
+ delta_matmul_state_chunk_f32(q_g_exp, new_state, attn_inter, chunk_size, S_v, S_v, n_seqs, H_v);
|
|
|
print_debug_info(attn_inter, chunk_size * S_v * H_v * n_seqs, "attn_inter_chunk", chunk);
|
|
|
|
|
|
// core_attn_out[:, :, i] = attn_inter + attn @ v_new
|
|
|
@@ -11203,19 +11211,16 @@ void ggml_compute_forward_delta_net_f32(const ggml_compute_params * params, ggml
|
|
|
}
|
|
|
}
|
|
|
print_debug_info(output, S_v * H_v * n_tokens * n_seqs, "output", chunk);
|
|
|
-
|
|
|
- // Update state tensor (all sequences and heads)
|
|
|
+ GGML_LOG_INFO("\nFull output tensor: \n\n");
|
|
|
for (int64_t seq = 0; seq < n_seqs; seq++) {
|
|
|
for (int64_t head = 0; head < H_v; head++) {
|
|
|
- float * temp_state_ptr = temp_state + seq * (S_v * S_v * H_v) + head * (S_v * S_v);
|
|
|
-
|
|
|
- for (int64_t i = 0; i < S_v; i++) {
|
|
|
- for (int64_t j = 0; j < S_v; j++) {
|
|
|
- int64_t state_idx = seq * S_v * S_v * H_v + head * S_v * S_v + i * S_v + j;
|
|
|
- new_state[state_idx] = temp_state_ptr[i * S_v + j];
|
|
|
- *(state_ptr(seq, head, i, j)) = temp_state_ptr[i * S_v + j];
|
|
|
+ GGML_LOG_INFO("\n[ ");
|
|
|
+ for (int64_t i = 0; i < n_tokens; i++) {
|
|
|
+ for (int64_t d = 0; d < S_v; d++) {
|
|
|
+ GGML_LOG_INFO("%.4f ", output[seq * (n_tokens * S_v * H_v) + head * (n_tokens * S_v) + (chunk * chunk_size + i) * S_v + d]);
|
|
|
}
|
|
|
}
|
|
|
+ GGML_LOG_INFO(" ]");
|
|
|
}
|
|
|
}
|
|
|
print_debug_info(new_state, S_v * S_v * H_v * n_seqs, "new_state", chunk);
|