|
|
@@ -2975,7 +2975,7 @@ static void ggml_compute_forward_tri_f32(const ggml_compute_params * params, ggm
|
|
|
const ggml_tensor * src0 = dst->src[0];
|
|
|
|
|
|
ggml_tri_type ttype = (ggml_tri_type) dst->op_params[0];
|
|
|
- float c = *((float *) &(dst->op_params[1]));
|
|
|
+ float c = ggml_get_op_params_f32(dst, 1);
|
|
|
bool keep_org_val = isnan(c);
|
|
|
|
|
|
GGML_ASSERT(ggml_is_contiguous(src0));
|
|
|
@@ -10902,7 +10902,6 @@ void ggml_compute_forward_delta_net_recurrent_f32(const ggml_compute_params * pa
|
|
|
// src6, src7, src8 are nullptr in recurrent version
|
|
|
|
|
|
const int64_t H_v = (int64_t) dst->op_params[0];
|
|
|
- const int64_t S_k = (int64_t) dst->op_params[1];
|
|
|
const int64_t S_v = (int64_t) dst->op_params[2];
|
|
|
const int64_t original_n_tokens = (int64_t) dst->op_params[3]; // Get original sequence length
|
|
|
const int64_t n_tokens = original_n_tokens; // Use the original sequence length
|
|
|
@@ -10972,7 +10971,7 @@ void ggml_compute_forward_delta_net_recurrent_f32(const ggml_compute_params * pa
|
|
|
}
|
|
|
}
|
|
|
}
|
|
|
- print_debug_info(temp_state, n_seqs * H_v * S_v * S_v, "temp_state_copy", token);
|
|
|
+ //print_debug_info(temp_state, n_seqs * H_v * S_v * S_v, "temp_state_copy", token);
|
|
|
|
|
|
// 1. last_recurrent_state = last_recurrent_state * g_t (for all seqs and heads)
|
|
|
for (int64_t seq = 0; seq < n_seqs; seq++) {
|
|
|
@@ -10986,7 +10985,7 @@ void ggml_compute_forward_delta_net_recurrent_f32(const ggml_compute_params * pa
|
|
|
}
|
|
|
}
|
|
|
}
|
|
|
- print_debug_info(temp_state, n_seqs * H_v * S_v * S_v, "temp_state_times_g_t", token);
|
|
|
+ //print_debug_info(temp_state, n_seqs * H_v * S_v * S_v, "temp_state_times_g_t", token);
|
|
|
|
|
|
// 2. kv_mem = (last_recurrent_state * k_t.unsqueeze(-1)).sum(dim=-2) (for all seqs and heads)
|
|
|
for (int64_t seq = 0; seq < n_seqs; seq++) {
|
|
|
@@ -11001,7 +11000,7 @@ void ggml_compute_forward_delta_net_recurrent_f32(const ggml_compute_params * pa
|
|
|
}
|
|
|
}
|
|
|
}
|
|
|
- print_debug_info(kv_mem, n_seqs * H_v * S_v, "kv_mem", token);
|
|
|
+ //print_debug_info(kv_mem, n_seqs * H_v * S_v, "kv_mem", token);
|
|
|
|
|
|
// 3. delta = (v_t - kv_mem) * beta_t (for all seqs and heads)
|
|
|
for (int64_t seq = 0; seq < n_seqs; seq++) {
|
|
|
@@ -11013,7 +11012,7 @@ void ggml_compute_forward_delta_net_recurrent_f32(const ggml_compute_params * pa
|
|
|
}
|
|
|
}
|
|
|
}
|
|
|
- print_debug_info(delta, n_seqs * H_v * S_v, "delta", token);
|
|
|
+ //print_debug_info(delta, n_seqs * H_v * S_v, "delta", token);
|
|
|
|
|
|
// 4. last_recurrent_state = last_recurrent_state + k_t.unsqueeze(-1) * delta.unsqueeze(-2) (for all seqs and heads)
|
|
|
for (int64_t seq = 0; seq < n_seqs; seq++) {
|
|
|
@@ -11027,7 +11026,7 @@ void ggml_compute_forward_delta_net_recurrent_f32(const ggml_compute_params * pa
|
|
|
}
|
|
|
}
|
|
|
}
|
|
|
- print_debug_info(temp_state, n_seqs * H_v * S_v * S_v, "temp_state", token);
|
|
|
+ //print_debug_info(temp_state, n_seqs * H_v * S_v * S_v, "temp_state", token);
|
|
|
|
|
|
// 5. core_attn_out[:, :, i] = (last_recurrent_state * q_t.unsqueeze(-1)).sum(dim=-2) (for all seqs and heads)
|
|
|
for (int64_t seq = 0; seq < n_seqs; seq++) {
|
|
|
@@ -11041,7 +11040,7 @@ void ggml_compute_forward_delta_net_recurrent_f32(const ggml_compute_params * pa
|
|
|
}
|
|
|
}
|
|
|
}
|
|
|
- print_debug_info(attn_out_t, n_seqs * S_v * H_v, "attn_out_t", token);
|
|
|
+ //print_debug_info(attn_out_t, n_seqs * S_v * H_v, "attn_out_t", token);
|
|
|
|
|
|
// Store the output for this token (for all seqs and heads)
|
|
|
for (int64_t seq = 0; seq < n_seqs; seq++) {
|