@@ -9861,6 +9861,14 @@ void ggml_compute_forward_unary(
             {
                 ggml_compute_forward_exp(params, dst);
             } break;
+        case GGML_UNARY_OP_EXPM1:
+            {
+                ggml_compute_forward_expm1(params, dst);
+            } break;
+        case GGML_UNARY_OP_SOFTPLUS:
+            {
+                ggml_compute_forward_softplus(params, dst);
+            } break;
         default:
             {
                 GGML_ABORT("fatal error");
@@ -10874,6 +10882,200 @@ void ggml_compute_forward_delta_net_f32(const ggml_compute_params * params, ggml
     }
 }
 
+// Debug helper: print the first five values and the sum of a float buffer.
+static void print_debug_info(float * data, size_t size, const char * name, int64_t token) {
+    GGML_ASSERT(size >= 5); // the format string below reads data[0..4]
+    GGML_LOG_INFO("\nggml-debug: %s (%lld) first 5 values: [%.6f, %.6f, %.6f, %.6f, %.6f, ...]\n",
+                  name, (long long) token, data[0], data[1], data[2], data[3], data[4]);
+    double sum = 0.0;
+    for (size_t i = 0; i < size; i++) {
+        sum += data[i];
+    }
+    GGML_LOG_INFO("sum = %.10f\n", sum);
+}
+
+void ggml_compute_forward_delta_net_recurrent_f32(const ggml_compute_params * params, ggml_tensor * dst) {
+    const struct ggml_tensor * src0 = dst->src[0]; // q_tokens
+    const struct ggml_tensor * src1 = dst->src[1]; // k_tokens
+    const struct ggml_tensor * src2 = dst->src[2]; // v_tokens
+    const struct ggml_tensor * src3 = dst->src[3]; // g_tokens_exp
+    const struct ggml_tensor * src4 = dst->src[4]; // beta_tokens
+    const struct ggml_tensor * src5 = dst->src[5]; // state
+    // src6, src7, src8 are nullptr in the recurrent version
+
+    const int64_t H_v      = (int64_t) dst->op_params[0];
+    const int64_t S_k      = (int64_t) dst->op_params[1];
+    const int64_t S_v      = (int64_t) dst->op_params[2];
+    const int64_t n_tokens = (int64_t) dst->op_params[3]; // original (unpadded) sequence length
+    const int64_t n_seqs   = src0->ne[3];                 // q tensor has n_seqs in dim 3
+
+    // Verify tensor dimensions
+    GGML_ASSERT(src0->ne[3] == n_seqs); // q tensor
+    GGML_ASSERT(src1->ne[3] == n_seqs); // k tensor
+    GGML_ASSERT(src2->ne[3] == n_seqs); // v tensor
+    GGML_ASSERT(src3->ne[3] == n_seqs); // g tensor
+    GGML_ASSERT(src4->ne[3] == n_seqs); // beta tensor
+    GGML_ASSERT(src5->ne[3] == n_seqs); // state tensor
+    GGML_ASSERT(S_k == S_v);            // the loops below bound the k/q index by S_v
+
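+    // Tensor shapes as implied by the indexing in this kernel (ggml order, ne[0]..ne[3]):
+    //   q, k        : [n_tokens, S_k, H_v, n_seqs]
+    //   v           : [n_tokens, S_v, H_v, n_seqs]
+    //   g_exp, beta : [n_tokens, 1,   H_v, n_seqs]
+    //   state       : [S_v, S_v, H_v, n_seqs]
+    //   dst         : output [S_v, H_v, n_tokens, n_seqs] followed by the final state [S_v, S_v, H_v, n_seqs]
+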
+    float * dst_data = (float *) dst->data;
+    // The output occupies the first part of dst, the final state the second part
+    float * output      = dst_data;                                   // [S_v * H_v * n_tokens * n_seqs]
+    float * final_state = dst_data + (S_v * H_v * n_tokens * n_seqs); // [S_v * S_v * H_v * n_seqs]
+
+    const int ith = params->ith;
+    // const int nth = params->nth; // unused: this kernel runs single-threaded
+
+    // Clear the output and final-state sections; only thread 0 computes, the rest return
+    if (ith == 0) {
+        memset(output, 0, ((S_v * H_v * n_tokens * n_seqs) + (S_v * S_v * H_v * n_seqs)) * sizeof(float));
+    } else {
+        return;
+    }
+
+    float * state_data = (float *) src5->data; // state is src5
+
+    GGML_ASSERT(ggml_is_contiguous(src0));
+    GGML_ASSERT(ggml_is_contiguous(src1));
+    GGML_ASSERT(ggml_is_contiguous(src2));
+    GGML_ASSERT(ggml_is_contiguous(src3));
+    GGML_ASSERT(ggml_is_contiguous(src4));
+    GGML_ASSERT(ggml_is_contiguous(src5));
+
+    // Pointer into the state tensor at (seq, head, i, j), converting the byte strides nb[] to float offsets
+    const auto state_ptr = [state_data, src5] (int64_t seq, int64_t head, int64_t i, int64_t j) {
+        return state_data + (j * src5->nb[0] / sizeof(float)) + (i * src5->nb[1] / sizeof(float)) +
+               (head * src5->nb[2] / sizeof(float)) + (seq * src5->nb[3] / sizeof(float));
+    };
+
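+    // Per-token delta rule implemented by steps 1-5 below (matching the PyTorch reference):
+    //   1. S      <- S * exp(g_t)              decay the state
+    //   2. kv_mem <- sum_i S[i,:] * k_t[i]     read the value currently stored for key k_t
+    //   3. delta  <- (v_t - kv_mem) * beta_t   prediction error, gated by beta
+    //   4. S      <- S + outer(k_t, delta)     write the correction back into the state
+    //   5. out_t  <- sum_i S[i,:] * q_t[i]     read the output with query q_t
+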
+    // Scratch buffers reused by every token iteration (allocated once, freed after the loop)
+    float * delta      = (float *) malloc(S_v * H_v * n_seqs * sizeof(float));
+    float * kv_mem     = (float *) malloc(S_v * H_v * n_seqs * sizeof(float));
+    float * attn_out_t = (float *) malloc(S_v * H_v * n_seqs * sizeof(float));
+    float * temp_state = (float *) malloc(S_v * S_v * H_v * n_seqs * sizeof(float)); // working copy of the state
+    GGML_ASSERT(delta != NULL && kv_mem != NULL && attn_out_t != NULL && temp_state != NULL);
+
+    // Process each token sequentially across all sequences and heads (recurrent processing),
+    // following the PyTorch reference: for each token, process all sequences and heads
+    for (int64_t token = 0; token < n_tokens; token++) {
+        const auto q_t     = [token, src0] (int64_t seq, int64_t head, int64_t i) { return ggml_get_f32_nd(src0, token, i, head, seq); };
+        const auto k_t     = [token, src1] (int64_t seq, int64_t head, int64_t i) { return ggml_get_f32_nd(src1, token, i, head, seq); };
+        const auto v_t     = [token, src2] (int64_t seq, int64_t head, int64_t i) { return ggml_get_f32_nd(src2, token, i, head, seq); };
+        const auto g_exp_t = [token, src3] (int64_t seq, int64_t head) { return ggml_get_f32_nd(src3, token, 0, head, seq); };
+        const auto beta_t  = [token, src4] (int64_t seq, int64_t head) { return ggml_get_f32_nd(src4, token, 0, head, seq); };
+
+        // Initialize temp_state with the current state values for all sequences and heads
+        for (int64_t seq = 0; seq < n_seqs; seq++) {
+            for (int64_t head = 0; head < H_v; head++) {
+                for (int64_t i = 0; i < S_v; i++) {
+                    for (int64_t j = 0; j < S_v; j++) {
+                        int64_t idx = seq * (S_v * S_v * H_v) + head * (S_v * S_v) + i * S_v + j;
+                        temp_state[idx] = *(state_ptr(seq, head, i, j));
+                    }
+                }
+            }
+        }
+        print_debug_info(temp_state, n_seqs * H_v * S_v * S_v, "temp_state_copy", token);
+
+        // 1. last_recurrent_state = last_recurrent_state * g_t (for all seqs and heads)
+        for (int64_t seq = 0; seq < n_seqs; seq++) {
+            for (int64_t head = 0; head < H_v; head++) {
+                float g_exp = g_exp_t(seq, head);
+                for (int64_t i = 0; i < S_v; i++) {
+                    for (int64_t j = 0; j < S_v; j++) {
+                        int64_t idx = seq * (S_v * S_v * H_v) + head * (S_v * S_v) + i * S_v + j;
+                        temp_state[idx] *= g_exp;
+                    }
+                }
+            }
+        }
+        print_debug_info(temp_state, n_seqs * H_v * S_v * S_v, "temp_state_times_g_t", token);
+
+        // 2. kv_mem = (last_recurrent_state * k_t.unsqueeze(-1)).sum(dim=-2) (for all seqs and heads)
+        for (int64_t seq = 0; seq < n_seqs; seq++) {
+            for (int64_t head = 0; head < H_v; head++) {
+                for (int64_t j = 0; j < S_v; j++) {
+                    kv_mem[seq * H_v * S_v + head * S_v + j] = 0.0f;
+                    for (int64_t i = 0; i < S_v; i++) {
+                        int64_t state_idx = seq * (S_v * S_v * H_v) + head * (S_v * S_v) + i * S_v + j;
+                        // sum over the key dimension: (last_recurrent_state * k_t.unsqueeze(-1)).sum(dim=-2)
+                        kv_mem[seq * H_v * S_v + head * S_v + j] += temp_state[state_idx] * k_t(seq, head, i);
+                    }
+                }
+            }
+        }
+        print_debug_info(kv_mem, n_seqs * H_v * S_v, "kv_mem", token);
+
+        // 3. delta = (v_t - kv_mem) * beta_t (for all seqs and heads)
+        for (int64_t seq = 0; seq < n_seqs; seq++) {
+            for (int64_t head = 0; head < H_v; head++) {
+                float beta_val = beta_t(seq, head);
+                for (int64_t j = 0; j < S_v; j++) {
+                    delta[seq * H_v * S_v + head * S_v + j] =
+                        (v_t(seq, head, j) - kv_mem[seq * H_v * S_v + head * S_v + j]) * beta_val;
+                }
+            }
+        }
+        print_debug_info(delta, n_seqs * H_v * S_v, "delta", token);
+
+        // 4. last_recurrent_state = last_recurrent_state + k_t.unsqueeze(-1) * delta.unsqueeze(-2) (for all seqs and heads)
+        for (int64_t seq = 0; seq < n_seqs; seq++) {
+            for (int64_t head = 0; head < H_v; head++) {
+                for (int64_t i = 0; i < S_v; i++) {
+                    for (int64_t j = 0; j < S_v; j++) {
+                        int64_t state_idx = seq * (S_v * S_v * H_v) + head * (S_v * S_v) + i * S_v + j;
+                        // outer product: k_t[i] * delta[j] (delta is broadcast as a row vector)
+                        temp_state[state_idx] += k_t(seq, head, i) * delta[seq * H_v * S_v + head * S_v + j];
+                    }
+                }
+            }
+        }
+        print_debug_info(temp_state, n_seqs * H_v * S_v * S_v, "temp_state", token);
+
+        // 5. core_attn_out[:, :, i] = (last_recurrent_state * q_t.unsqueeze(-1)).sum(dim=-2) (for all seqs and heads)
+        for (int64_t seq = 0; seq < n_seqs; seq++) {
+            for (int64_t head = 0; head < H_v; head++) {
+                for (int64_t j = 0; j < S_v; j++) {
+                    attn_out_t[seq * H_v * S_v + head * S_v + j] = 0.0f;
+                    for (int64_t i = 0; i < S_v; i++) {
+                        int64_t state_idx = seq * (S_v * S_v * H_v) + head * (S_v * S_v) + i * S_v + j;
+                        attn_out_t[seq * H_v * S_v + head * S_v + j] += temp_state[state_idx] * q_t(seq, head, i);
+                    }
+                }
+            }
+        }
+        print_debug_info(attn_out_t, n_seqs * S_v * H_v, "attn_out_t", token);
+
+        // Store the output for this token (for all seqs and heads)
+        for (int64_t seq = 0; seq < n_seqs; seq++) {
+            for (int64_t head = 0; head < H_v; head++) {
+                for (int64_t d = 0; d < S_v; d++) {
+                    int64_t output_idx = d + head * S_v + token * (S_v * H_v) + seq * (S_v * H_v * n_tokens);
+                    output[output_idx] = attn_out_t[seq * H_v * S_v + head * S_v + d];
+                }
+            }
+        }
+
+        // Update the working state for the next token iteration (in the state tensor, for all seqs and heads)
+        for (int64_t seq = 0; seq < n_seqs; seq++) {
+            for (int64_t head = 0; head < H_v; head++) {
+                for (int64_t i = 0; i < S_v; i++) {
+                    for (int64_t j = 0; j < S_v; j++) {
+                        int64_t state_idx = seq * (S_v * S_v * H_v) + head * (S_v * S_v) + i * S_v + j;
+                        *(state_ptr(seq, head, i, j)) = temp_state[state_idx];
+
+                        // Store the final state for this head and sequence (for output);
+                        // note the index order: it is transposed (i fastest) relative to temp_state (j fastest)
+                        int64_t final_state_idx = i + j * S_v + head * (S_v * S_v) + seq * (S_v * S_v * H_v);
+                        final_state[final_state_idx] = temp_state[state_idx];
+                    }
+                }
+            }
+        }
+    }
+
+    free(temp_state);
+    free(delta);
+    free(kv_mem);
+    free(attn_out_t);
+}
+
 // ggml_compute_forward_rwkv_wkv7
 static void ggml_compute_forward_rwkv_wkv7_f32(
         const ggml_compute_params * params,