@@ -52,10 +52,8 @@ struct ggml_tensor * ggml_delta_net(
GGML_ASSERT(k->ne[0] == S_k && k->ne[1] == H_k && k->ne[3] == n_tokens);

GGML_ASSERT(g->ne[0] == S_v && g->ne[1] == H_v && g->ne[3] == n_tokens && g->ne[2] == batch_size);
-
- struct ggml_tensor * beta_sigmoid = ggml_sigmoid(ctx, beta);
- report_tensor_size("beta_sigmoid", beta_sigmoid);
-
+
+ // Merge q, k, v into qkv
struct ggml_tensor * mixed_qkv = ggml_concat(ctx, q, k, 1);
report_tensor_size("mixed_qkv_qk", mixed_qkv);
mixed_qkv = ggml_concat(ctx, mixed_qkv, v, 1);
@@ -68,6 +66,7 @@ struct ggml_tensor * ggml_delta_net(
struct ggml_tensor * mixed_qkv_padded = ggml_pad(ctx, mixed_qkv, conv_weight->ne[0] - 1, 0, 0, 0);
report_tensor_size("mixed_qkv_padded", mixed_qkv_padded);

+ // Apply convolution
struct ggml_tensor * conv_out = ggml_ssm_conv(ctx, mixed_qkv_padded, conv_weight);
report_tensor_size("conv_out", conv_out);

@@ -85,68 +84,36 @@ struct ggml_tensor * ggml_delta_net(
conv_out = ggml_permute(ctx, conv_out, 0, 2, 1, 3);
report_tensor_size("conv_out_transposed", conv_out);

- struct ggml_tensor * q_conv = ggml_view_4d(ctx, conv_out,
- S_k, // ne0
- H_k, // ne1
- conv_out->ne[1], // ne2 = sequence length (1)
- conv_out->ne[2], // ne3 = batch (1)
- H_k * sizeof(float), // nb1 = stride along H_k
- conv_out->nb[1], // nb2 = stride along sequence dim
- conv_out->nb[2], // nb3 = stride along batch dim
- 0 // offset in bytes
- );
+ // Apply a sigmoid to beta so the write strength lies in (0, 1)
+ struct ggml_tensor * beta_sigmoid = ggml_sigmoid(ctx, beta);
+ report_tensor_size("beta_sigmoid", beta_sigmoid);
+
+ // Gate calculations are done elsewhere in llama-model.cpp
+
+ // Split the convolved qkv tensor back into separate q, k and v views
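+ // The offsets below assume conv_out keeps the concatenation order: q (S_k * H_k floats), then k (S_k * H_k), then v (S_v * H_v)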
+ struct ggml_tensor * q_conv = ggml_view_4d(ctx, conv_out, S_k, H_k, conv_out->ne[1], conv_out->ne[2],
+ H_k * sizeof(float), conv_out->nb[1], conv_out->nb[2], 0);
report_tensor_size("q_conv_view", q_conv);

- // k projection view
- struct ggml_tensor * k_conv = ggml_view_4d(ctx, conv_out,
- S_k, // ne0
- H_k, // ne1
- conv_out->ne[1], // ne2
- conv_out->ne[2], // ne3
- H_k * sizeof(float), // nb1
- conv_out->nb[1], // nb2
- conv_out->nb[2], // nb3
- S_k * H_k * sizeof(q->type) // offset = skip q_out
- );
+ struct ggml_tensor * k_conv = ggml_view_4d(ctx, conv_out, S_k, H_k, conv_out->ne[1], conv_out->ne[2],
+ H_k * sizeof(float), conv_out->nb[1], conv_out->nb[2], S_k * H_k * sizeof(float)); // offset skips the q section
report_tensor_size("k_conv_view", k_conv);

- // v projection view
- struct ggml_tensor * v_conv = ggml_view_4d(ctx, conv_out,
- S_v, // ne0
- H_v, // ne1
- conv_out->ne[1], // ne2
- conv_out->ne[2], // ne3
- H_v * sizeof(float), // nb1
- conv_out->nb[1], // nb2
- conv_out->nb[2], // nb3
- (2 * S_k * H_k) * sizeof(q->type)// offset = skip q_out + k_out
- );
+ struct ggml_tensor * v_conv = ggml_view_4d(ctx, conv_out, S_v, H_v, conv_out->ne[1], conv_out->ne[2], H_v * sizeof(float),
+ conv_out->nb[1], conv_out->nb[2], (2 * S_k * H_k) * sizeof(float)); // offset skips the q and k sections
report_tensor_size("v_conv_view", v_conv);

- q_conv = ggml_permute(ctx, q_conv, 0, 2, 1, 3);
- report_tensor_size("q_conv_permuted", q_conv);
- k_conv = ggml_permute(ctx, k_conv, 0, 2, 1, 3);
- report_tensor_size("k_conv_permuted", k_conv);
- v_conv = ggml_permute(ctx, v_conv, 0, 2, 1, 3);
- report_tensor_size("v_conv_permuted", v_conv);
-
- q_conv = ggml_reshape_3d(ctx, ggml_cont(ctx, q_conv), S_k * H_k, batch_size, n_tokens);
- report_tensor_size("q_conv_reshaped", q_conv);
- k_conv = ggml_reshape_3d(ctx, ggml_cont(ctx, k_conv), S_k * H_k, batch_size, n_tokens);
- report_tensor_size("k_conv_reshaped", k_conv);
- v_conv = ggml_reshape_3d(ctx, ggml_cont(ctx, v_conv), S_v * H_v, batch_size, n_tokens);
- report_tensor_size("v_conv_reshaped", v_conv);
-
struct ggml_tensor * q_broadcast = q_conv;
struct ggml_tensor * k_broadcast = k_conv;

+ // if the number of key heads and value heads differs, repeat q and k so all tensors end up with matching shapes
if (H_k != H_v) {
GGML_ASSERT(H_v % H_k == 0);
int64_t repeat_factor = H_v / H_k;
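+ // each q/k head is reused repeat_factor times so every value head gets a matching query and key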

- q_broadcast = ggml_reshape_4d(ctx, q_conv, S_k, batch_size, H_k, n_tokens);
+ q_broadcast = ggml_cont_4d(ctx, q_conv, S_k, batch_size, H_k, n_tokens);
report_tensor_size("q_broadcast_reshape1", q_broadcast);
- k_broadcast = ggml_reshape_4d(ctx, k_conv, S_k, batch_size, H_k, n_tokens);
+ k_broadcast = ggml_cont_4d(ctx, k_conv, S_k, batch_size, H_k, n_tokens);
report_tensor_size("k_broadcast_reshape1", k_broadcast);

q_broadcast = ggml_repeat_4d(ctx, q_broadcast, S_k, batch_size * repeat_factor, H_k, n_tokens);
@@ -160,24 +127,14 @@ struct ggml_tensor * ggml_delta_net(
report_tensor_size("k_broadcast_reshape2", k_broadcast);
}

- struct ggml_tensor * v_reshape = ggml_reshape_4d(ctx, v_conv, S_v, H_v, n_tokens, batch_size);
+ struct ggml_tensor * v_reshape = ggml_cont_4d(ctx, v_conv, S_v, H_v, n_tokens, batch_size);
report_tensor_size("v_reshape", v_reshape);
- struct ggml_tensor * v_broadcast = ggml_repeat_4d(ctx, v_reshape, S_v, H_v, n_tokens, batch_size);
- report_tensor_size("v_broadcast", v_broadcast);
- struct ggml_tensor * g_reshape = g;
- report_tensor_size("g_reshape", g_reshape);
- q_broadcast = ggml_repeat_4d(ctx, q_broadcast, S_k, H_v, n_tokens, batch_size);
- report_tensor_size("q_broadcast_final", q_broadcast);
- k_broadcast = ggml_repeat_4d(ctx, k_broadcast, S_k, H_v, n_tokens, batch_size);
- report_tensor_size("k_broadcast_final", k_broadcast);
- struct ggml_tensor * beta_reshape = ggml_reshape_4d(ctx, beta_sigmoid, 1, H_v, n_tokens, batch_size);
- report_tensor_size("beta_reshape", beta_reshape);
- struct ggml_tensor * beta_broadcast = ggml_repeat_4d(ctx, beta_reshape, 1, H_v, n_tokens, batch_size);
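+ // the write strength is laid out as [1, H_v, n_tokens, batch] so it can broadcast over S_v inside ggml_delta_net_op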
+ struct ggml_tensor * beta_broadcast = ggml_cont_4d(ctx, beta_sigmoid, 1, H_v, n_tokens, batch_size);
report_tensor_size("beta_broadcast", beta_broadcast);
struct ggml_tensor * state_broadcast = ggml_cont(ctx, state);
report_tensor_size("state_broadcast", state_broadcast);

- return ggml_delta_net_op(ctx, q_broadcast, k_broadcast, v_broadcast, g_reshape, beta_broadcast, state_broadcast, use_qk_l2norm, scale);
+ return ggml_delta_net_op(ctx, q_broadcast, k_broadcast, v_reshape, g, beta_broadcast, state_broadcast, use_qk_l2norm, scale);
}

struct ggml_tensor * ggml_delta_net_op(
@@ -212,9 +169,10 @@ struct ggml_tensor * ggml_delta_net_op(
const int64_t batch_size = q->ne[3];

const int64_t S_v = v->ne[0];
- const int64_t H_v = v->ne[1];
+ const int64_t H_v = v->ne[1];

GGML_LOG_INFO("S_k = %ld, S_v = %ld, H_k = %ld, H_v = %ld\n", S_k, S_v, H_k, H_v);
+ GGML_ASSERT(H_k == H_v); // ggml_delta_net already broadcast q and k, so the head counts must match here

GGML_ASSERT(k->ne[0] == S_k && k->ne[1] == H_v && k->ne[2] == n_tokens && k->ne[3] == batch_size);
GGML_ASSERT(v->ne[1] == H_v && v->ne[2] == n_tokens && v->ne[3] == batch_size);
@@ -289,71 +247,28 @@ struct ggml_tensor * ggml_delta_net_op(
struct ggml_tensor * state_t = state_2d;
report_tensor_size("state_t", state_t);

- struct ggml_tensor * state_t_transposed = ggml_transpose(ctx, state_t);
+ struct ggml_tensor * state_t_transposed = ggml_cont(ctx, ggml_transpose(ctx, state_t));
report_tensor_size("state_t_transposed", state_t_transposed);
-
+
struct ggml_tensor * k_t_final_reshaped = ggml_reshape_4d(ctx, k_t_final, H_v, S_k, batch_size, 1);
report_tensor_size("k_t_final_reshaped", k_t_final_reshaped);

- struct ggml_tensor * kv_mem = ggml_mul_mat(ctx, k_t_final_reshaped, state_t_transposed);
+ struct ggml_tensor * kv_mem = ggml_mul_mat(ctx, state_t_transposed, k_t_final_reshaped);
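+ // delta rule, read step: query the current state with k to get the value the memory currently stores for this key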
report_tensor_size("kv_mem", kv_mem);

struct ggml_tensor * v_t_final = v_t_reshaped;
struct ggml_tensor * beta_t_final = beta_t_reshaped;
-
- if (H_k != H_v) {
- struct ggml_tensor * v_t_4d = ggml_reshape_4d(ctx, v_t_reshaped, S_v, H_k, 1, batch_size);
- struct ggml_tensor * v_t_repeated = ggml_repeat_4d(ctx, v_t_4d, S_v, H_v, 1, batch_size);
- v_t_final = ggml_reshape_2d(ctx, v_t_repeated, S_v, H_v * batch_size);
-
- struct ggml_tensor * beta_t_4d = ggml_reshape_4d(ctx, beta_t_reshaped, 1, H_k, 1, batch_size);
- struct ggml_tensor * beta_t_repeated = ggml_repeat_4d(ctx, beta_t_4d, 1, H_v, 1, batch_size);
- beta_t_final = ggml_reshape_2d(ctx, beta_t_repeated, 1, H_v * batch_size);
- }
-
- struct ggml_tensor * kv_mem_reshaped;
- if (kv_mem->ne[0] == S_v && kv_mem->ne[1] == H_v * batch_size) {
- kv_mem_reshaped = kv_mem;
- } else if (kv_mem->ne[0] == S_v) {
- kv_mem_reshaped = ggml_view_2d(ctx, kv_mem, S_v, H_v * batch_size, kv_mem->nb[1], 0);
- } else {
- report_tensor_size("kv_mem_before_reshape", kv_mem);
- kv_mem_reshaped = ggml_reshape_2d(ctx, kv_mem, S_v, H_v * batch_size);
- }
- kv_mem_reshaped = ggml_cont(ctx, kv_mem_reshaped);
+
+ struct ggml_tensor * kv_mem_reshaped = ggml_transpose(ctx, kv_mem);
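+ // transposed so its shape lines up with v_t_final for the element-wise subtraction below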
report_tensor_size("kv_mem_reshaped", kv_mem_reshaped);
-
- struct ggml_tensor * kv_mem_final;
- if (kv_mem_reshaped->ne[0] == v_t_final->ne[0] && kv_mem_reshaped->ne[1] == v_t_final->ne[1]) {
- kv_mem_final = kv_mem_reshaped;
- } else {
- kv_mem_final = ggml_repeat(ctx, kv_mem_reshaped, v_t_final);
- }
- report_tensor_size("kv_mem_final", kv_mem_final);
-
- struct ggml_tensor * delta = ggml_mul(ctx, ggml_sub(ctx, v_t_final, kv_mem_final), beta_t_final);
+
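+ // delta rule, write amount: the prediction error (v - kv_mem) scaled by the write strength beta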
+ struct ggml_tensor * delta = ggml_mul(ctx, ggml_sub(ctx, v_t_final, kv_mem_reshaped), beta_t_final);
report_tensor_size("delta", delta);

struct ggml_tensor * delta_reshaped = ggml_reshape_2d(ctx, delta, S_v, H_v * batch_size);
report_tensor_size("delta_reshaped", delta_reshaped);
-
- if (H_k == H_v) {
- k_t_final = k_t_reshaped;
- } else {
- int64_t repeat_factor = H_v / H_k;
- GGML_ASSERT(H_v % H_k == 0);
-
- k_t_final = ggml_reshape_3d(ctx, k_t_reshaped, S_k, 1, H_k * batch_size);
- report_tensor_size("k_t_final_reshape1", k_t_final);
-
- k_t_final = ggml_repeat_4d(ctx, k_t_final, S_k, repeat_factor, H_k, batch_size);
- report_tensor_size("k_t_final_repeat", k_t_final);
-
- k_t_final = ggml_reshape_2d(ctx, k_t_final, S_k, H_v * batch_size);
- report_tensor_size("k_t_final_reshape2", k_t_final);
- }
-
- k_t_final = ggml_cont(ctx, k_t_final);
+
+ k_t_final = ggml_cont(ctx, k_t_reshaped);
report_tensor_size("k_t_final_cont", k_t_final);

struct ggml_tensor * k_t_for_outer;