|
|
@@ -6622,6 +6622,7 @@ static struct ggml_tensor * llm_build_kqv(
|
|
|
const int64_t n_embd_head_k = hparams.n_embd_head_k;
|
|
|
const int64_t n_embd_k_gqa = hparams.n_embd_k_gqa();
|
|
|
const int64_t n_embd_head_v = hparams.n_embd_head_v;
|
|
|
+ const int64_t n_embd_v_gqa = hparams.n_embd_v_gqa();
|
|
|
|
|
|
struct ggml_tensor * q = ggml_permute(ctx, q_cur, 0, 2, 1, 3);
|
|
|
cb(q, "q", il);
|
|
|
@@ -6644,8 +6645,8 @@ static struct ggml_tensor * llm_build_kqv(
|
|
|
struct ggml_tensor * v =
|
|
|
ggml_view_3d(ctx, kv.v_l[il],
|
|
|
n_embd_head_v, n_kv, n_head_kv,
|
|
|
- ggml_row_size(kv.v_l[il]->type, n_embd_k_gqa),
|
|
|
- ggml_row_size(kv.v_l[il]->type, n_embd_head_k),
|
|
|
+ ggml_row_size(kv.v_l[il]->type, n_embd_v_gqa),
|
|
|
+ ggml_row_size(kv.v_l[il]->type, n_embd_head_v),
|
|
|
0);
|
|
|
cb(v, "v", il);
|
|
|
|
|
|
@@ -6655,7 +6656,7 @@ static struct ggml_tensor * llm_build_kqv(
|
|
|
ggml_flash_attn_ext_set_prec(cur, GGML_PREC_F32);
|
|
|
}
|
|
|
|
|
|
- cur = ggml_reshape_2d(ctx, cur, n_embd_head_k*n_head, n_tokens);
|
|
|
+ cur = ggml_reshape_2d(ctx, cur, n_embd_head_v*n_head, n_tokens);
|
|
|
} else {
|
|
|
struct ggml_tensor * kq = ggml_mul_mat(ctx, k, q);
|
|
|
cb(kq, "kq", il);
|
|
|
@@ -6700,7 +6701,7 @@ static struct ggml_tensor * llm_build_kqv(
|
|
|
struct ggml_tensor * kqv_merged = ggml_permute(ctx, kqv, 0, 2, 1, 3);
|
|
|
cb(kqv_merged, "kqv_merged", il);
|
|
|
|
|
|
- cur = ggml_cont_2d(ctx, kqv_merged, n_embd_head_k*n_head, n_tokens);
|
|
|
+ cur = ggml_cont_2d(ctx, kqv_merged, n_embd_head_v*n_head, n_tokens);
|
|
|
cb(cur, "kqv_merged_cont", il);
|
|
|
}
|
|
|
|