|
@@ -51,14 +51,11 @@ struct ggml_tensor * ggml_delta_net(
|
|
|
GGML_ASSERT(q->ne[0] == S_k && q->ne[1] == H_k && q->ne[3] == n_tokens);
|
|
GGML_ASSERT(q->ne[0] == S_k && q->ne[1] == H_k && q->ne[3] == n_tokens);
|
|
|
GGML_ASSERT(k->ne[0] == S_k && k->ne[1] == H_k && k->ne[3] == n_tokens);
|
|
GGML_ASSERT(k->ne[0] == S_k && k->ne[1] == H_k && k->ne[3] == n_tokens);
|
|
|
|
|
|
|
|
- // Validate g dimensions - g should be [S_v, H_v, n_tokens, batch_size] based on actual tensor layout
|
|
|
|
|
GGML_ASSERT(g->ne[0] == S_v && g->ne[1] == H_v && g->ne[3] == n_tokens && g->ne[2] == batch_size);
|
|
GGML_ASSERT(g->ne[0] == S_v && g->ne[1] == H_v && g->ne[3] == n_tokens && g->ne[2] == batch_size);
|
|
|
|
|
|
|
|
- // Apply sigmoid to beta for gating
|
|
|
|
|
struct ggml_tensor * beta_sigmoid = ggml_sigmoid(ctx, beta);
|
|
struct ggml_tensor * beta_sigmoid = ggml_sigmoid(ctx, beta);
|
|
|
report_tensor_size("beta_sigmoid", beta_sigmoid);
|
|
report_tensor_size("beta_sigmoid", beta_sigmoid);
|
|
|
|
|
|
|
|
- // Concatenate q, k, v for convolution processing
|
|
|
|
|
struct ggml_tensor * mixed_qkv = ggml_concat(ctx, q, k, 1);
|
|
struct ggml_tensor * mixed_qkv = ggml_concat(ctx, q, k, 1);
|
|
|
report_tensor_size("mixed_qkv_qk", mixed_qkv);
|
|
report_tensor_size("mixed_qkv_qk", mixed_qkv);
|
|
|
mixed_qkv = ggml_concat(ctx, mixed_qkv, v, 1);
|
|
mixed_qkv = ggml_concat(ctx, mixed_qkv, v, 1);
|
|
@@ -71,29 +68,23 @@ struct ggml_tensor * ggml_delta_net(
|
|
|
struct ggml_tensor * mixed_qkv_padded = ggml_pad(ctx, mixed_qkv, conv_weight->ne[0] - 1, 0, 0, 0);
|
|
struct ggml_tensor * mixed_qkv_padded = ggml_pad(ctx, mixed_qkv, conv_weight->ne[0] - 1, 0, 0, 0);
|
|
|
report_tensor_size("mixed_qkv_padded", mixed_qkv_padded);
|
|
report_tensor_size("mixed_qkv_padded", mixed_qkv_padded);
|
|
|
|
|
|
|
|
- // Apply SSM convolution
|
|
|
|
|
struct ggml_tensor * conv_out = ggml_ssm_conv(ctx, mixed_qkv_padded, conv_weight);
|
|
struct ggml_tensor * conv_out = ggml_ssm_conv(ctx, mixed_qkv_padded, conv_weight);
|
|
|
report_tensor_size("conv_out", conv_out);
|
|
report_tensor_size("conv_out", conv_out);
|
|
|
|
|
|
|
|
- // Apply bias if provided
|
|
|
|
|
if (conv_bias) {
|
|
if (conv_bias) {
|
|
|
conv_out = ggml_add(ctx, conv_out, conv_bias);
|
|
conv_out = ggml_add(ctx, conv_out, conv_bias);
|
|
|
report_tensor_size("conv_out_bias", conv_out);
|
|
report_tensor_size("conv_out_bias", conv_out);
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
- // Apply SiLU activation
|
|
|
|
|
conv_out = ggml_silu(ctx, conv_out);
|
|
conv_out = ggml_silu(ctx, conv_out);
|
|
|
report_tensor_size("conv_out_silu", conv_out);
|
|
report_tensor_size("conv_out_silu", conv_out);
|
|
|
|
|
|
|
|
- // Reshape back to 4D: [dim, n_tokens, 1] -> [dim, n_tokens, 1, 1]
|
|
|
|
|
conv_out = ggml_reshape_4d(ctx, conv_out, dim, n_tokens, batch_size, 1);
|
|
conv_out = ggml_reshape_4d(ctx, conv_out, dim, n_tokens, batch_size, 1);
|
|
|
report_tensor_size("conv_out_reshaped", conv_out);
|
|
report_tensor_size("conv_out_reshaped", conv_out);
|
|
|
|
|
|
|
|
- // Transpose to get the right layout: [dim, n_tokens, 1] -> [dim, 1, n_tokens, 1]
|
|
|
|
|
conv_out = ggml_permute(ctx, conv_out, 0, 2, 1, 3);
|
|
conv_out = ggml_permute(ctx, conv_out, 0, 2, 1, 3);
|
|
|
report_tensor_size("conv_out_transposed", conv_out);
|
|
report_tensor_size("conv_out_transposed", conv_out);
|
|
|
|
|
|
|
|
- // q projection view
|
|
|
|
|
struct ggml_tensor * q_conv = ggml_view_4d(ctx, conv_out,
|
|
struct ggml_tensor * q_conv = ggml_view_4d(ctx, conv_out,
|
|
|
S_k, // ne0
|
|
S_k, // ne0
|
|
|
H_k, // ne1
|
|
H_k, // ne1
|
|
@@ -132,7 +123,6 @@ struct ggml_tensor * ggml_delta_net(
|
|
|
);
|
|
);
|
|
|
report_tensor_size("v_conv_view", v_conv);
|
|
report_tensor_size("v_conv_view", v_conv);
|
|
|
|
|
|
|
|
- // Transpose each component back to original layout: [S_v, 1, token_split_size, 1] -> [S_v, token_split_size, 1, 1]
|
|
|
|
|
q_conv = ggml_permute(ctx, q_conv, 0, 2, 1, 3);
|
|
q_conv = ggml_permute(ctx, q_conv, 0, 2, 1, 3);
|
|
|
report_tensor_size("q_conv_permuted", q_conv);
|
|
report_tensor_size("q_conv_permuted", q_conv);
|
|
|
k_conv = ggml_permute(ctx, k_conv, 0, 2, 1, 3);
|
|
k_conv = ggml_permute(ctx, k_conv, 0, 2, 1, 3);
|
|
@@ -147,29 +137,23 @@ struct ggml_tensor * ggml_delta_net(
|
|
|
v_conv = ggml_reshape_3d(ctx, ggml_cont(ctx, v_conv), S_v * H_v, batch_size, n_tokens);
|
|
v_conv = ggml_reshape_3d(ctx, ggml_cont(ctx, v_conv), S_v * H_v, batch_size, n_tokens);
|
|
|
report_tensor_size("v_conv_reshaped", v_conv);
|
|
report_tensor_size("v_conv_reshaped", v_conv);
|
|
|
|
|
|
|
|
- // NOW we repeat query and key to match value head dimensions if needed (after convolution)
|
|
|
|
|
struct ggml_tensor * q_broadcast = q_conv;
|
|
struct ggml_tensor * q_broadcast = q_conv;
|
|
|
struct ggml_tensor * k_broadcast = k_conv;
|
|
struct ggml_tensor * k_broadcast = k_conv;
|
|
|
|
|
|
|
|
if (H_k != H_v) {
|
|
if (H_k != H_v) {
|
|
|
- // Calculate the repeat factor: H_v / H_k
|
|
|
|
|
GGML_ASSERT(H_v % H_k == 0);
|
|
GGML_ASSERT(H_v % H_k == 0);
|
|
|
int64_t repeat_factor = H_v / H_k;
|
|
int64_t repeat_factor = H_v / H_k;
|
|
|
|
|
|
|
|
- // Repeat query and key along the head dimension
|
|
|
|
|
- // First reshape to separate the repeat dimension: [S_k, H_k, n_tokens, 1] -> [S_k, 1, H_k, n_tokens]
|
|
|
|
|
q_broadcast = ggml_reshape_4d(ctx, q_conv, S_k, batch_size, H_k, n_tokens);
|
|
q_broadcast = ggml_reshape_4d(ctx, q_conv, S_k, batch_size, H_k, n_tokens);
|
|
|
report_tensor_size("q_broadcast_reshape1", q_broadcast);
|
|
report_tensor_size("q_broadcast_reshape1", q_broadcast);
|
|
|
k_broadcast = ggml_reshape_4d(ctx, k_conv, S_k, batch_size, H_k, n_tokens);
|
|
k_broadcast = ggml_reshape_4d(ctx, k_conv, S_k, batch_size, H_k, n_tokens);
|
|
|
report_tensor_size("k_broadcast_reshape1", k_broadcast);
|
|
report_tensor_size("k_broadcast_reshape1", k_broadcast);
|
|
|
|
|
|
|
|
- // Repeat along the new dimension: [S_k, repeat_factor, H_k, n_tokens]
|
|
|
|
|
q_broadcast = ggml_repeat_4d(ctx, q_broadcast, S_k, batch_size * repeat_factor, H_k, n_tokens);
|
|
q_broadcast = ggml_repeat_4d(ctx, q_broadcast, S_k, batch_size * repeat_factor, H_k, n_tokens);
|
|
|
report_tensor_size("q_broadcast_repeat", q_broadcast);
|
|
report_tensor_size("q_broadcast_repeat", q_broadcast);
|
|
|
k_broadcast = ggml_repeat_4d(ctx, k_broadcast, S_k, batch_size * repeat_factor, H_k, n_tokens);
|
|
k_broadcast = ggml_repeat_4d(ctx, k_broadcast, S_k, batch_size * repeat_factor, H_k, n_tokens);
|
|
|
report_tensor_size("k_broadcast_repeat", k_broadcast);
|
|
report_tensor_size("k_broadcast_repeat", k_broadcast);
|
|
|
|
|
|
|
|
- // Reshape back to original dimensions but with H_v heads: [S_k, H_v, n_tokens, 1]
|
|
|
|
|
q_broadcast = ggml_reshape_4d(ctx, q_broadcast, S_k, H_v, n_tokens, batch_size);
|
|
q_broadcast = ggml_reshape_4d(ctx, q_broadcast, S_k, H_v, n_tokens, batch_size);
|
|
|
report_tensor_size("q_broadcast_reshape2", q_broadcast);
|
|
report_tensor_size("q_broadcast_reshape2", q_broadcast);
|
|
|
k_broadcast = ggml_reshape_4d(ctx, k_broadcast, S_k, H_v, n_tokens, batch_size);
|
|
k_broadcast = ggml_reshape_4d(ctx, k_broadcast, S_k, H_v, n_tokens, batch_size);
|
|
@@ -180,7 +164,6 @@ struct ggml_tensor * ggml_delta_net(
|
|
|
report_tensor_size("v_reshape", v_reshape);
|
|
report_tensor_size("v_reshape", v_reshape);
|
|
|
struct ggml_tensor * v_broadcast = ggml_repeat_4d(ctx, v_reshape, S_v, H_v, n_tokens, batch_size);
|
|
struct ggml_tensor * v_broadcast = ggml_repeat_4d(ctx, v_reshape, S_v, H_v, n_tokens, batch_size);
|
|
|
report_tensor_size("v_broadcast", v_broadcast);
|
|
report_tensor_size("v_broadcast", v_broadcast);
|
|
|
- // g already has correct dimensions [S_v, H_v, n_tokens, batch_size], no need to reshape
|
|
|
|
|
struct ggml_tensor * g_reshape = g;
|
|
struct ggml_tensor * g_reshape = g;
|
|
|
report_tensor_size("g_reshape", g_reshape);
|
|
report_tensor_size("g_reshape", g_reshape);
|
|
|
q_broadcast = ggml_repeat_4d(ctx, q_broadcast, S_k, H_v, n_tokens, batch_size);
|
|
q_broadcast = ggml_repeat_4d(ctx, q_broadcast, S_k, H_v, n_tokens, batch_size);
|
|
@@ -191,13 +174,9 @@ struct ggml_tensor * ggml_delta_net(
|
|
|
report_tensor_size("beta_reshape", beta_reshape);
|
|
report_tensor_size("beta_reshape", beta_reshape);
|
|
|
struct ggml_tensor * beta_broadcast = ggml_repeat_4d(ctx, beta_reshape, 1, H_v, n_tokens, batch_size);
|
|
struct ggml_tensor * beta_broadcast = ggml_repeat_4d(ctx, beta_reshape, 1, H_v, n_tokens, batch_size);
|
|
|
report_tensor_size("beta_broadcast", beta_broadcast);
|
|
report_tensor_size("beta_broadcast", beta_broadcast);
|
|
|
- // The state should be repeated along the sequence dimension only
|
|
|
|
|
- // Original state: [S_v, S_v, H_v, 1] -> should become [S_v, S_v, H_v, n_seqs]
|
|
|
|
|
- // Use ggml_cont to ensure the state is contiguous, not ggml_repeat_4d which would repeat along all dimensions
|
|
|
|
|
struct ggml_tensor * state_broadcast = ggml_cont(ctx, state);
|
|
struct ggml_tensor * state_broadcast = ggml_cont(ctx, state);
|
|
|
report_tensor_size("state_broadcast", state_broadcast);
|
|
report_tensor_size("state_broadcast", state_broadcast);
|
|
|
|
|
|
|
|
- // Call tensor-level kernel with convolved and processed tensors
|
|
|
|
|
return ggml_delta_net_op(ctx, q_broadcast, k_broadcast, v_broadcast, g_reshape, beta_broadcast, state_broadcast, use_qk_l2norm, scale);
|
|
return ggml_delta_net_op(ctx, q_broadcast, k_broadcast, v_broadcast, g_reshape, beta_broadcast, state_broadcast, use_qk_l2norm, scale);
|
|
|
}
|
|
}
|
|
|
|
|
|
|
@@ -220,7 +199,6 @@ struct ggml_tensor * ggml_delta_net_op(
|
|
|
report_tensor_size("beta_input", beta);
|
|
report_tensor_size("beta_input", beta);
|
|
|
report_tensor_size("state_input", state);
|
|
report_tensor_size("state_input", state);
|
|
|
|
|
|
|
|
- // Validate dimensions
|
|
|
|
|
GGML_ASSERT(ggml_is_contiguous(q));
|
|
GGML_ASSERT(ggml_is_contiguous(q));
|
|
|
GGML_ASSERT(ggml_is_contiguous(k));
|
|
GGML_ASSERT(ggml_is_contiguous(k));
|
|
|
GGML_ASSERT(ggml_is_contiguous(v));
|
|
GGML_ASSERT(ggml_is_contiguous(v));
|
|
@@ -228,17 +206,16 @@ struct ggml_tensor * ggml_delta_net_op(
|
|
|
GGML_ASSERT(ggml_is_contiguous(beta));
|
|
GGML_ASSERT(ggml_is_contiguous(beta));
|
|
|
GGML_ASSERT(ggml_is_contiguous(state));
|
|
GGML_ASSERT(ggml_is_contiguous(state));
|
|
|
|
|
|
|
|
- const int64_t S_k = q->ne[0]; // head dimension for q/k
|
|
|
|
|
- const int64_t H_k = q->ne[1]; // number of heads (already processed to match v)
|
|
|
|
|
|
|
+ const int64_t S_k = q->ne[0];
|
|
|
|
|
+ const int64_t H_k = q->ne[1];
|
|
|
const int64_t n_tokens = q->ne[2];
|
|
const int64_t n_tokens = q->ne[2];
|
|
|
- const int64_t batch_size = q->ne[3]; // batch size, not n_seqs
|
|
|
|
|
|
|
+ const int64_t batch_size = q->ne[3];
|
|
|
|
|
|
|
|
- const int64_t S_v = v->ne[0]; // head dimension for v
|
|
|
|
|
- const int64_t H_v = v->ne[1]; // head dimension for v
|
|
|
|
|
|
|
+ const int64_t S_v = v->ne[0];
|
|
|
|
|
+ const int64_t H_v = v->ne[1];
|
|
|
|
|
|
|
|
GGML_LOG_INFO("S_k = %ld, S_v = %ld, H_k = %ld, H_v = %ld\n", S_k, S_v, H_k, H_v);
|
|
GGML_LOG_INFO("S_k = %ld, S_v = %ld, H_k = %ld, H_v = %ld\n", S_k, S_v, H_k, H_v);
|
|
|
|
|
|
|
|
- // Validate dimensions - match Python implementation layout
|
|
|
|
|
GGML_ASSERT(k->ne[0] == S_k && k->ne[1] == H_v && k->ne[2] == n_tokens && k->ne[3] == batch_size);
|
|
GGML_ASSERT(k->ne[0] == S_k && k->ne[1] == H_v && k->ne[2] == n_tokens && k->ne[3] == batch_size);
|
|
|
GGML_ASSERT(v->ne[1] == H_v && v->ne[2] == n_tokens && v->ne[3] == batch_size);
|
|
GGML_ASSERT(v->ne[1] == H_v && v->ne[2] == n_tokens && v->ne[3] == batch_size);
|
|
|
GGML_ASSERT(g->ne[0] == S_v && g->ne[1] == H_v && g->ne[3] == n_tokens && g->ne[2] == batch_size);
|
|
GGML_ASSERT(g->ne[0] == S_v && g->ne[1] == H_v && g->ne[3] == n_tokens && g->ne[2] == batch_size);
|
|
@@ -250,13 +227,9 @@ struct ggml_tensor * ggml_delta_net_op(
|
|
|
|
|
|
|
|
struct ggml_tensor * new_state = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, S_v * S_v, H_v, 1, n_tokens);
|
|
struct ggml_tensor * new_state = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, S_v * S_v, H_v, 1, n_tokens);
|
|
|
|
|
|
|
|
- // Copy initial state to new_state
|
|
|
|
|
new_state = ggml_cpy(ctx, state, new_state);
|
|
new_state = ggml_cpy(ctx, state, new_state);
|
|
|
report_tensor_size("new_state_copied", new_state);
|
|
report_tensor_size("new_state_copied", new_state);
|
|
|
|
|
|
|
|
- // Process all sequences and heads together using tensor operations
|
|
|
|
|
-
|
|
|
|
|
- // Apply L2 normalization if requested - per head, token, and sequence
|
|
|
|
|
if (use_qk_l2norm) {
|
|
if (use_qk_l2norm) {
|
|
|
q = ggml_l2_norm(ctx, q, 1e-6f);
|
|
q = ggml_l2_norm(ctx, q, 1e-6f);
|
|
|
report_tensor_size("q_l2norm", q);
|
|
report_tensor_size("q_l2norm", q);
|
|
@@ -264,20 +237,13 @@ struct ggml_tensor * ggml_delta_net_op(
|
|
|
report_tensor_size("k_l2norm", k);
|
|
report_tensor_size("k_l2norm", k);
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
- // Apply scaling to query - across all tokens, sequences and heads
|
|
|
|
|
q = ggml_scale(ctx, q, scale);
|
|
q = ggml_scale(ctx, q, scale);
|
|
|
report_tensor_size("q_scaled", q);
|
|
report_tensor_size("q_scaled", q);
|
|
|
|
|
|
|
|
- // Process the gated delta rule using tensor operations
|
|
|
|
|
-
|
|
|
|
|
- // Reshape for matrix operations: [S_v, S_v, H_v, 1] -> [S_v * S_v, H_v]
|
|
|
|
|
struct ggml_tensor * state_flat = ggml_reshape_2d(ctx, new_state, S_v * S_v, H_v);
|
|
struct ggml_tensor * state_flat = ggml_reshape_2d(ctx, new_state, S_v * S_v, H_v);
|
|
|
report_tensor_size("state_flat", state_flat);
|
|
report_tensor_size("state_flat", state_flat);
|
|
|
|
|
|
|
|
- // Process each token sequentially due to recurrent nature
|
|
|
|
|
for (int64_t t = 0; t < n_tokens; ++t) {
|
|
for (int64_t t = 0; t < n_tokens; ++t) {
|
|
|
- // Extract current token's data across all batches and heads
|
|
|
|
|
- // q, k, v are [S_k, H_k, n_tokens, batch_size] layout in GGML
|
|
|
|
|
struct ggml_tensor * q_t = ggml_view_3d(ctx, q, S_k, H_k, batch_size,
|
|
struct ggml_tensor * q_t = ggml_view_3d(ctx, q, S_k, H_k, batch_size,
|
|
|
q->nb[1], q->nb[2], t * q->nb[2]);
|
|
q->nb[1], q->nb[2], t * q->nb[2]);
|
|
|
report_tensor_size("q_t_view", q_t);
|
|
report_tensor_size("q_t_view", q_t);
|
|
@@ -291,403 +257,222 @@ struct ggml_tensor * ggml_delta_net_op(
|
|
|
beta->nb[1], beta->nb[2], t * beta->nb[2]);
|
|
beta->nb[1], beta->nb[2], t * beta->nb[2]);
|
|
|
report_tensor_size("beta_t_view", beta_t);
|
|
report_tensor_size("beta_t_view", beta_t);
|
|
|
|
|
|
|
|
- // Simplified approach: follow Python implementation exactly
|
|
|
|
|
- // In Python: kv_mem = (last_recurrent_state * k_t.unsqueeze(-1)).sum(dim=-2)
|
|
|
|
|
- // This means: for each batch and head, multiply state by k_t and sum over the last dimension
|
|
|
|
|
-
|
|
|
|
|
- // First, reshape tensors to match GGML layout for head-wise processing
|
|
|
|
|
- // q_t: [S_k, H_k, batch_size] -> reshape to [S_k, H_k * batch_size]
|
|
|
|
|
struct ggml_tensor * q_t_reshaped = ggml_reshape_2d(ctx, q_t, S_k, H_k * batch_size);
|
|
struct ggml_tensor * q_t_reshaped = ggml_reshape_2d(ctx, q_t, S_k, H_k * batch_size);
|
|
|
report_tensor_size("q_t_reshaped", q_t_reshaped);
|
|
report_tensor_size("q_t_reshaped", q_t_reshaped);
|
|
|
|
|
|
|
|
- // k_t: [S_k, H_k, batch_size] -> reshape to [S_k, H_k * batch_size]
|
|
|
|
|
struct ggml_tensor * k_t_reshaped = ggml_reshape_2d(ctx, k_t, S_k, H_k * batch_size);
|
|
struct ggml_tensor * k_t_reshaped = ggml_reshape_2d(ctx, k_t, S_k, H_k * batch_size);
|
|
|
report_tensor_size("k_t_reshaped", k_t_reshaped);
|
|
report_tensor_size("k_t_reshaped", k_t_reshaped);
|
|
|
|
|
|
|
|
- // v_t: [S_v, H_v, batch_size] -> reshape to [S_v, H_v * batch_size]
|
|
|
|
|
struct ggml_tensor * v_t_reshaped = ggml_reshape_2d(ctx, v_t, S_v, H_v * batch_size);
|
|
struct ggml_tensor * v_t_reshaped = ggml_reshape_2d(ctx, v_t, S_v, H_v * batch_size);
|
|
|
report_tensor_size("v_t_reshaped", v_t_reshaped);
|
|
report_tensor_size("v_t_reshaped", v_t_reshaped);
|
|
|
|
|
|
|
|
- // beta_t: [1, H_v, batch_size] -> reshape to [1, H_v * batch_size]
|
|
|
|
|
struct ggml_tensor * beta_t_reshaped = ggml_reshape_2d(ctx, beta_t, 1, H_v * batch_size);
|
|
struct ggml_tensor * beta_t_reshaped = ggml_reshape_2d(ctx, beta_t, 1, H_v * batch_size);
|
|
|
report_tensor_size("beta_t_reshaped", beta_t_reshaped);
|
|
report_tensor_size("beta_t_reshaped", beta_t_reshaped);
|
|
|
|
|
|
|
|
- // Handle head dimension mismatch - repeat k_t if needed
|
|
|
|
|
struct ggml_tensor * k_t_final = k_t_reshaped;
|
|
struct ggml_tensor * k_t_final = k_t_reshaped;
|
|
|
if (H_k != H_v) {
|
|
if (H_k != H_v) {
|
|
|
GGML_ASSERT(H_v % H_k == 0);
|
|
GGML_ASSERT(H_v % H_k == 0);
|
|
|
|
|
|
|
|
- // Reshape k_t to separate head and batch dimensions: [S_k, H_k, batch_size, 1]
|
|
|
|
|
struct ggml_tensor * k_t_4d = ggml_reshape_4d(ctx, k_t_reshaped, S_k, H_k, 1, batch_size);
|
|
struct ggml_tensor * k_t_4d = ggml_reshape_4d(ctx, k_t_reshaped, S_k, H_k, 1, batch_size);
|
|
|
report_tensor_size("k_t_4d", k_t_4d);
|
|
report_tensor_size("k_t_4d", k_t_4d);
|
|
|
|
|
|
|
|
- // Repeat along head dimension: [S_k, H_v, batch_size, 1]
|
|
|
|
|
k_t_final = ggml_repeat_4d(ctx, k_t_4d, S_k, H_v, 1, batch_size);
|
|
k_t_final = ggml_repeat_4d(ctx, k_t_4d, S_k, H_v, 1, batch_size);
|
|
|
report_tensor_size("k_t_final_repeated", k_t_final);
|
|
report_tensor_size("k_t_final_repeated", k_t_final);
|
|
|
|
|
|
|
|
- // Reshape back to 2D: [S_k, H_v * batch_size]
|
|
|
|
|
k_t_final = ggml_reshape_2d(ctx, k_t_final, S_k, H_v * batch_size);
|
|
k_t_final = ggml_reshape_2d(ctx, k_t_final, S_k, H_v * batch_size);
|
|
|
report_tensor_size("k_t_final_2d", k_t_final);
|
|
report_tensor_size("k_t_final_2d", k_t_final);
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
- // Simplified kv_mem computation: state @ k_t^T for each head
|
|
|
|
|
- // For now, let's use a simpler approach that matches the Python logic more closely
|
|
|
|
|
- // kv_mem = (state * k_t.unsqueeze(-1)).sum(dim=-2)
|
|
|
|
|
-
|
|
|
|
|
- // Reshape state to [S_v * S_v, H_v] for easier processing
|
|
|
|
|
struct ggml_tensor * state_2d = ggml_reshape_2d(ctx, new_state, S_v * S_v, H_v);
|
|
struct ggml_tensor * state_2d = ggml_reshape_2d(ctx, new_state, S_v * S_v, H_v);
|
|
|
report_tensor_size("state_2d", state_2d);
|
|
report_tensor_size("state_2d", state_2d);
|
|
|
|
|
|
|
|
- // The state is already in the correct format for matrix operations
|
|
|
|
|
struct ggml_tensor * state_t = state_2d;
|
|
struct ggml_tensor * state_t = state_2d;
|
|
|
report_tensor_size("state_t", state_t);
|
|
report_tensor_size("state_t", state_t);
|
|
|
|
|
|
|
|
- // Simple kv_mem computation for this token
|
|
|
|
|
- // kv_mem = (state_t * k_t.unsqueeze(-1)).sum(dim=-2)
|
|
|
|
|
- // In GGML, we need to implement: (state_t * k_t_broadcast).sum(dim=1)
|
|
|
|
|
- // state_t: [S_v * S_v, H_v], k_t_final: [S_k, H_v * batch_size]
|
|
|
|
|
-
|
|
|
|
|
- // For the correct matrix multiplication, we need:
|
|
|
|
|
- // state_t: [S_v * S_v, H_v]
|
|
|
|
|
- // k_t_final: [S_k, H_v * batch_size]
|
|
|
|
|
- // We want: state_t @ k_t_transposed where k_t_transposed is [H_v * batch_size, S_k]
|
|
|
|
|
-
|
|
|
|
|
- // But first, let's check if we can do a simpler approach
|
|
|
|
|
- // Since we have H_v = 16 and batch_size = 1, we have:
|
|
|
|
|
- // state_t: [16384, 16] and k_t_final: [128, 16]
|
|
|
|
|
-
|
|
|
|
|
- // For matrix multiplication, we need: [16384, 16] @ [16, 128] = [16384, 128]
|
|
|
|
|
- // So we need to transpose k_t_final to get [16, 128]
|
|
|
|
|
-
|
|
|
|
|
- // For GGML matrix multiplication, we need to satisfy ggml_can_mul_mat requirements:
|
|
|
|
|
- // t0->ne[0] == t1->ne[0] (first dimensions must be equal)
|
|
|
|
|
- // t1->ne[2]%t0->ne[2] == 0 (broadcastable along 3rd dimension)
|
|
|
|
|
- // t1->ne[3]%t0->ne[3] == 0 (broadcastable along 4th dimension)
|
|
|
|
|
-
|
|
|
|
|
- // We need to reshape state_t from [S_v * S_v, H_v, 1, 1] to [H_v, S_v * S_v, 1, 1]
|
|
|
|
|
- // and k_t_final from [S_k, H_v * batch_size] to [H_v, S_k, 1, 1]
|
|
|
|
|
-
|
|
|
|
|
- // First, transpose state_t to get [H_v, S_v * S_v, 1, 1]
|
|
|
|
|
struct ggml_tensor * state_t_transposed = ggml_transpose(ctx, state_t);
|
|
struct ggml_tensor * state_t_transposed = ggml_transpose(ctx, state_t);
|
|
|
report_tensor_size("state_t_transposed", state_t_transposed);
|
|
report_tensor_size("state_t_transposed", state_t_transposed);
|
|
|
|
|
|
|
|
- // Reshape k_t_final from [S_k, H_v * batch_size] to [H_v, S_k, 1, 1]
|
|
|
|
|
struct ggml_tensor * k_t_final_reshaped = ggml_reshape_4d(ctx, k_t_final, H_v, S_k, batch_size, 1);
|
|
struct ggml_tensor * k_t_final_reshaped = ggml_reshape_4d(ctx, k_t_final, H_v, S_k, batch_size, 1);
|
|
|
report_tensor_size("k_t_final_reshaped", k_t_final_reshaped);
|
|
report_tensor_size("k_t_final_reshaped", k_t_final_reshaped);
|
|
|
|
|
|
|
|
- // Now we can do matrix multiplication: k_t_final_reshaped^T @ state_t_transposed^T
|
|
|
|
|
- // But GGML doesn't allow transposed first argument, so we need to swap the order
|
|
|
|
|
- // and transpose the result if needed
|
|
|
|
|
struct ggml_tensor * kv_mem = ggml_mul_mat(ctx, k_t_final_reshaped, state_t_transposed);
|
|
struct ggml_tensor * kv_mem = ggml_mul_mat(ctx, k_t_final_reshaped, state_t_transposed);
|
|
|
report_tensor_size("kv_mem", kv_mem);
|
|
report_tensor_size("kv_mem", kv_mem);
|
|
|
|
|
|
|
|
- // Compute delta = (v_t - kv_mem) * beta_t
|
|
|
|
|
- // kv_mem: [batch_size, S_v] (result of state @ k_t^T)
|
|
|
|
|
- // v_t: [batch_size, H_v, S_v] -> reshape to [batch_size * H_v, S_v]
|
|
|
|
|
- // beta_t: [batch_size, H_v, 1] -> reshape to [batch_size * H_v, 1]
|
|
|
|
|
-
|
|
|
|
|
- // Handle head dimension mismatch for v_t and beta_t
|
|
|
|
|
struct ggml_tensor * v_t_final = v_t_reshaped;
|
|
struct ggml_tensor * v_t_final = v_t_reshaped;
|
|
|
struct ggml_tensor * beta_t_final = beta_t_reshaped;
|
|
struct ggml_tensor * beta_t_final = beta_t_reshaped;
|
|
|
|
|
|
|
|
if (H_k != H_v) {
|
|
if (H_k != H_v) {
|
|
|
- // Repeat v_t and beta_t along head dimension to match H_v
|
|
|
|
|
- // v_t: [S_v, H_k, batch_size] -> [S_v, H_k, batch_size, 1] -> repeat -> [S_v, H_v, batch_size, 1]
|
|
|
|
|
struct ggml_tensor * v_t_4d = ggml_reshape_4d(ctx, v_t_reshaped, S_v, H_k, 1, batch_size);
|
|
struct ggml_tensor * v_t_4d = ggml_reshape_4d(ctx, v_t_reshaped, S_v, H_k, 1, batch_size);
|
|
|
struct ggml_tensor * v_t_repeated = ggml_repeat_4d(ctx, v_t_4d, S_v, H_v, 1, batch_size);
|
|
struct ggml_tensor * v_t_repeated = ggml_repeat_4d(ctx, v_t_4d, S_v, H_v, 1, batch_size);
|
|
|
v_t_final = ggml_reshape_2d(ctx, v_t_repeated, S_v, H_v * batch_size);
|
|
v_t_final = ggml_reshape_2d(ctx, v_t_repeated, S_v, H_v * batch_size);
|
|
|
|
|
|
|
|
- // beta_t: [1, H_k, batch_size] -> [1, H_k, batch_size, 1] -> repeat -> [1, H_v, batch_size, 1]
|
|
|
|
|
struct ggml_tensor * beta_t_4d = ggml_reshape_4d(ctx, beta_t_reshaped, 1, H_k, 1, batch_size);
|
|
struct ggml_tensor * beta_t_4d = ggml_reshape_4d(ctx, beta_t_reshaped, 1, H_k, 1, batch_size);
|
|
|
struct ggml_tensor * beta_t_repeated = ggml_repeat_4d(ctx, beta_t_4d, 1, H_v, 1, batch_size);
|
|
struct ggml_tensor * beta_t_repeated = ggml_repeat_4d(ctx, beta_t_4d, 1, H_v, 1, batch_size);
|
|
|
beta_t_final = ggml_reshape_2d(ctx, beta_t_repeated, 1, H_v * batch_size);
|
|
beta_t_final = ggml_reshape_2d(ctx, beta_t_repeated, 1, H_v * batch_size);
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
- // Ensure kv_mem has correct dimensions for subtraction
|
|
|
|
|
- // kv_mem dimensions from trace: [128, 16384, 1, 1]
|
|
|
|
|
- // We need to reshape it to match v_t_final: [128, 16, 1, 1]
|
|
|
|
|
-
|
|
|
|
|
- // First, let's reshape kv_mem to the correct dimensions
|
|
|
|
|
struct ggml_tensor * kv_mem_reshaped;
|
|
struct ggml_tensor * kv_mem_reshaped;
|
|
|
if (kv_mem->ne[0] == S_v && kv_mem->ne[1] == H_v * batch_size) {
|
|
if (kv_mem->ne[0] == S_v && kv_mem->ne[1] == H_v * batch_size) {
|
|
|
- // Perfect match
|
|
|
|
|
kv_mem_reshaped = kv_mem;
|
|
kv_mem_reshaped = kv_mem;
|
|
|
} else if (kv_mem->ne[0] == S_v) {
|
|
} else if (kv_mem->ne[0] == S_v) {
|
|
|
- // We have the right first dimension, need to fix the second dimension
|
|
|
|
|
kv_mem_reshaped = ggml_view_2d(ctx, kv_mem, S_v, H_v * batch_size, kv_mem->nb[1], 0);
|
|
kv_mem_reshaped = ggml_view_2d(ctx, kv_mem, S_v, H_v * batch_size, kv_mem->nb[1], 0);
|
|
|
} else {
|
|
} else {
|
|
|
- // Handle other dimension mismatches
|
|
|
|
|
report_tensor_size("kv_mem_before_reshape", kv_mem);
|
|
report_tensor_size("kv_mem_before_reshape", kv_mem);
|
|
|
kv_mem_reshaped = ggml_reshape_2d(ctx, kv_mem, S_v, H_v * batch_size);
|
|
kv_mem_reshaped = ggml_reshape_2d(ctx, kv_mem, S_v, H_v * batch_size);
|
|
|
}
|
|
}
|
|
|
kv_mem_reshaped = ggml_cont(ctx, kv_mem_reshaped);
|
|
kv_mem_reshaped = ggml_cont(ctx, kv_mem_reshaped);
|
|
|
report_tensor_size("kv_mem_reshaped", kv_mem_reshaped);
|
|
report_tensor_size("kv_mem_reshaped", kv_mem_reshaped);
|
|
|
|
|
|
|
|
- // Now ensure kv_mem_reshaped has the same dimensions as v_t_final
|
|
|
|
|
struct ggml_tensor * kv_mem_final;
|
|
struct ggml_tensor * kv_mem_final;
|
|
|
if (kv_mem_reshaped->ne[0] == v_t_final->ne[0] && kv_mem_reshaped->ne[1] == v_t_final->ne[1]) {
|
|
if (kv_mem_reshaped->ne[0] == v_t_final->ne[0] && kv_mem_reshaped->ne[1] == v_t_final->ne[1]) {
|
|
|
kv_mem_final = kv_mem_reshaped;
|
|
kv_mem_final = kv_mem_reshaped;
|
|
|
} else {
|
|
} else {
|
|
|
- // Use repeat to match dimensions if they're compatible
|
|
|
|
|
kv_mem_final = ggml_repeat(ctx, kv_mem_reshaped, v_t_final);
|
|
kv_mem_final = ggml_repeat(ctx, kv_mem_reshaped, v_t_final);
|
|
|
}
|
|
}
|
|
|
report_tensor_size("kv_mem_final", kv_mem_final);
|
|
report_tensor_size("kv_mem_final", kv_mem_final);
|
|
|
|
|
|
|
|
- // Compute delta = (v_t - kv_mem) * beta_t
|
|
|
|
|
struct ggml_tensor * delta = ggml_mul(ctx, ggml_sub(ctx, v_t_final, kv_mem_final), beta_t_final);
|
|
struct ggml_tensor * delta = ggml_mul(ctx, ggml_sub(ctx, v_t_final, kv_mem_final), beta_t_final);
|
|
|
report_tensor_size("delta", delta);
|
|
report_tensor_size("delta", delta);
|
|
|
|
|
|
|
|
- // Update state: state = state + outer(k_t, delta)
|
|
|
|
|
struct ggml_tensor * delta_reshaped = ggml_reshape_2d(ctx, delta, S_v, H_v * batch_size);
|
|
struct ggml_tensor * delta_reshaped = ggml_reshape_2d(ctx, delta, S_v, H_v * batch_size);
|
|
|
report_tensor_size("delta_reshaped", delta_reshaped);
|
|
report_tensor_size("delta_reshaped", delta_reshaped);
|
|
|
|
|
|
|
|
- // Handle the outer product for all heads and batches
|
|
|
|
|
- // We need to compute outer(k_t, delta) where:
|
|
|
|
|
- // k_t is [S_k * H_k, batch_size] -> reshape to [S_k, H_k * batch_size]
|
|
|
|
|
- // delta is [S_v, H_v * batch_size]
|
|
|
|
|
- // For outer product, we want k_t @ delta^T
|
|
|
|
|
-
|
|
|
|
|
- // First, handle head dimension mismatch for k_t (reuse existing k_t_final variable)
|
|
|
|
|
if (H_k == H_v) {
|
|
if (H_k == H_v) {
|
|
|
k_t_final = k_t_reshaped;
|
|
k_t_final = k_t_reshaped;
|
|
|
} else {
|
|
} else {
|
|
|
- // Need to repeat k along the head dimension to match H_v
|
|
|
|
|
int64_t repeat_factor = H_v / H_k;
|
|
int64_t repeat_factor = H_v / H_k;
|
|
|
GGML_ASSERT(H_v % H_k == 0);
|
|
GGML_ASSERT(H_v % H_k == 0);
|
|
|
|
|
|
|
|
- // Reshape to separate repeat dimension: [S_k, 1, H_k, batch_size]
|
|
|
|
|
k_t_final = ggml_reshape_3d(ctx, k_t_reshaped, S_k, 1, H_k * batch_size);
|
|
k_t_final = ggml_reshape_3d(ctx, k_t_reshaped, S_k, 1, H_k * batch_size);
|
|
|
report_tensor_size("k_t_final_reshape1", k_t_final);
|
|
report_tensor_size("k_t_final_reshape1", k_t_final);
|
|
|
|
|
|
|
|
- // Repeat along the new dimension: [S_k, repeat_factor, H_k, batch_size]
|
|
|
|
|
k_t_final = ggml_repeat_4d(ctx, k_t_final, S_k, repeat_factor, H_k, batch_size);
|
|
k_t_final = ggml_repeat_4d(ctx, k_t_final, S_k, repeat_factor, H_k, batch_size);
|
|
|
report_tensor_size("k_t_final_repeat", k_t_final);
|
|
report_tensor_size("k_t_final_repeat", k_t_final);
|
|
|
|
|
|
|
|
- // Reshape back: [S_k, H_v * batch_size]
|
|
|
|
|
k_t_final = ggml_reshape_2d(ctx, k_t_final, S_k, H_v * batch_size);
|
|
k_t_final = ggml_reshape_2d(ctx, k_t_final, S_k, H_v * batch_size);
|
|
|
report_tensor_size("k_t_final_reshape2", k_t_final);
|
|
report_tensor_size("k_t_final_reshape2", k_t_final);
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
- // Make k_t_final contiguous
|
|
|
|
|
k_t_final = ggml_cont(ctx, k_t_final);
|
|
k_t_final = ggml_cont(ctx, k_t_final);
|
|
|
report_tensor_size("k_t_final_cont", k_t_final);
|
|
report_tensor_size("k_t_final_cont", k_t_final);
|
|
|
|
|
|
|
|
- // Handle dimension mismatch between S_k and S_v
|
|
|
|
|
struct ggml_tensor * k_t_for_outer;
|
|
struct ggml_tensor * k_t_for_outer;
|
|
|
if (S_k == S_v) {
|
|
if (S_k == S_v) {
|
|
|
k_t_for_outer = k_t_final;
|
|
k_t_for_outer = k_t_final;
|
|
|
} else if (S_k < S_v) {
|
|
} else if (S_k < S_v) {
|
|
|
- // Pad k_t to match S_v
|
|
|
|
|
struct ggml_tensor * padding = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, S_v - S_k, H_v * batch_size);
|
|
struct ggml_tensor * padding = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, S_v - S_k, H_v * batch_size);
|
|
|
report_tensor_size("k_t_padding", padding);
|
|
report_tensor_size("k_t_padding", padding);
|
|
|
k_t_for_outer = ggml_concat(ctx, k_t_final, padding, 0);
|
|
k_t_for_outer = ggml_concat(ctx, k_t_final, padding, 0);
|
|
|
report_tensor_size("k_t_for_outer_padded", k_t_for_outer);
|
|
report_tensor_size("k_t_for_outer_padded", k_t_for_outer);
|
|
|
} else {
|
|
} else {
|
|
|
- // Truncate k_t to match S_v
|
|
|
|
|
k_t_for_outer = ggml_view_2d(ctx, k_t_final, S_v, H_v * batch_size, k_t_final->nb[1], 0);
|
|
k_t_for_outer = ggml_view_2d(ctx, k_t_final, S_v, H_v * batch_size, k_t_final->nb[1], 0);
|
|
|
report_tensor_size("k_t_for_outer_truncated", k_t_for_outer);
|
|
report_tensor_size("k_t_for_outer_truncated", k_t_for_outer);
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
- // Make sure k_t_for_outer is contiguous
|
|
|
|
|
k_t_for_outer = ggml_cont(ctx, k_t_for_outer);
|
|
k_t_for_outer = ggml_cont(ctx, k_t_for_outer);
|
|
|
report_tensor_size("k_t_for_outer_cont", k_t_for_outer);
|
|
report_tensor_size("k_t_for_outer_cont", k_t_for_outer);
|
|
|
|
|
|
|
|
- // Compute outer product: k_t_for_outer @ delta_reshaped^T
|
|
|
|
|
- // k_t_for_outer: [S_v, H_v * batch_size]
|
|
|
|
|
- // delta_reshaped: [S_v, H_v * batch_size]
|
|
|
|
|
- // For outer product, we want: k_t_for_outer @ delta_reshaped^T
|
|
|
|
|
-
|
|
|
|
|
- // We need to satisfy ggml_can_mul_mat requirements:
|
|
|
|
|
- // t0->ne[0] == t1->ne[0] (first dimensions must be equal)
|
|
|
|
|
- // t1->ne[2]%t0->ne[2] == 0 (broadcastable along 3rd dimension)
|
|
|
|
|
- // t1->ne[3]%t0->ne[3] == 0 (broadcastable along 4th dimension)
|
|
|
|
|
-
|
|
|
|
|
- // First, reshape k_t_for_outer to [S_v, H_v * batch_size, 1, 1]
|
|
|
|
|
struct ggml_tensor * k_t_reshaped_4d = ggml_reshape_4d(ctx, k_t_for_outer, S_v, H_v, 1, batch_size);
|
|
struct ggml_tensor * k_t_reshaped_4d = ggml_reshape_4d(ctx, k_t_for_outer, S_v, H_v, 1, batch_size);
|
|
|
report_tensor_size("k_t_reshaped_4d", k_t_reshaped_4d);
|
|
report_tensor_size("k_t_reshaped_4d", k_t_reshaped_4d);
|
|
|
|
|
|
|
|
- // Transpose delta_reshaped to get [H_v * batch_size, S_v]
|
|
|
|
|
struct ggml_tensor * delta_transposed = ggml_transpose(ctx, delta_reshaped);
|
|
struct ggml_tensor * delta_transposed = ggml_transpose(ctx, delta_reshaped);
|
|
|
report_tensor_size("delta_transposed", delta_transposed);
|
|
report_tensor_size("delta_transposed", delta_transposed);
|
|
|
|
|
|
|
|
- // Make delta_transposed contiguous before reshaping
|
|
|
|
|
delta_transposed = ggml_cont(ctx, delta_transposed);
|
|
delta_transposed = ggml_cont(ctx, delta_transposed);
|
|
|
report_tensor_size("delta_transposed_cont", delta_transposed);
|
|
report_tensor_size("delta_transposed_cont", delta_transposed);
|
|
|
|
|
|
|
|
- // Reshape delta_transposed to [H_v * batch_size, S_v, 1, 1]
|
|
|
|
|
struct ggml_tensor * delta_reshaped_4d = ggml_reshape_4d(ctx, delta_transposed, H_v, S_v, 1, batch_size);
|
|
struct ggml_tensor * delta_reshaped_4d = ggml_reshape_4d(ctx, delta_transposed, H_v, S_v, 1, batch_size);
|
|
|
report_tensor_size("delta_reshaped_4d", delta_reshaped_4d);
|
|
report_tensor_size("delta_reshaped_4d", delta_reshaped_4d);
|
|
|
|
|
|
|
|
- // For outer product k @ delta^T, we need: [S_v, H_v * batch_size] @ [H_v * batch_size, S_v] = [S_v, S_v]
|
|
|
|
|
- // But GGML requires the first dimensions to be equal for matrix multiplication
|
|
|
|
|
- // So we need to transpose the first tensor: k_t_reshaped_4d^T @ delta_reshaped_4d
|
|
|
|
|
- // [H_v * batch_size, S_v] @ [H_v * batch_size, S_v] - this won't work
|
|
|
|
|
-
|
|
|
|
|
- // Instead, we need to do: delta_reshaped_4d^T @ k_t_reshaped_4d^T
|
|
|
|
|
- // But GGML doesn't allow transposed first argument, so we need to swap the order
|
|
|
|
|
- // and transpose the result if needed
|
|
|
|
|
-
|
|
|
|
|
- // Let's do: delta_reshaped_4d^T @ k_t_reshaped_4d
|
|
|
|
|
- // [S_v, H_v * batch_size] @ [S_v, H_v * batch_size] - this won't work either
|
|
|
|
|
-
|
|
|
|
|
- // The correct approach is: k_t_reshaped_4d @ delta_reshaped_4d^T
|
|
|
|
|
- // But we need to make the first dimensions equal by transposing k_t_reshaped_4d
|
|
|
|
|
struct ggml_tensor * k_t_transposed = ggml_transpose(ctx, k_t_reshaped_4d);
|
|
struct ggml_tensor * k_t_transposed = ggml_transpose(ctx, k_t_reshaped_4d);
|
|
|
report_tensor_size("k_t_transposed", k_t_transposed);
|
|
report_tensor_size("k_t_transposed", k_t_transposed);
|
|
|
|
|
|
|
|
- // Now we can do: k_t_transposed @ delta_reshaped_4d
|
|
|
|
|
- // [H_v * batch_size, S_v] @ [H_v * batch_size, S_v] - still won't work
|
|
|
|
|
-
|
|
|
|
|
- // Let's try a different approach: use the transpose of the result
|
|
|
|
|
- // We want: k @ delta^T = (delta @ k^T)^T
|
|
|
|
|
struct ggml_tensor * temp_product = ggml_mul_mat(ctx, delta_reshaped_4d, k_t_transposed);
|
|
struct ggml_tensor * temp_product = ggml_mul_mat(ctx, delta_reshaped_4d, k_t_transposed);
|
|
|
report_tensor_size("temp_product", temp_product);
|
|
report_tensor_size("temp_product", temp_product);
|
|
|
|
|
|
|
|
- // Transpose the result to get the final outer product
|
|
|
|
|
struct ggml_tensor * outer_product_raw = ggml_transpose(ctx, temp_product);
|
|
struct ggml_tensor * outer_product_raw = ggml_transpose(ctx, temp_product);
|
|
|
report_tensor_size("outer_product_raw", outer_product_raw);
|
|
report_tensor_size("outer_product_raw", outer_product_raw);
|
|
|
|
|
|
|
|
- // Make outer_product_raw contiguous before reshaping
|
|
|
|
|
struct ggml_tensor * outer_product_cont = ggml_cont(ctx, outer_product_raw);
|
|
struct ggml_tensor * outer_product_cont = ggml_cont(ctx, outer_product_raw);
|
|
|
report_tensor_size("outer_product_cont", outer_product_cont);
|
|
report_tensor_size("outer_product_cont", outer_product_cont);
|
|
|
|
|
|
|
|
- // Reshape to 2D: [S_v, S_v]
|
|
|
|
|
struct ggml_tensor * outer_product = ggml_reshape_2d(ctx, outer_product_cont, S_v, S_v);
|
|
struct ggml_tensor * outer_product = ggml_reshape_2d(ctx, outer_product_cont, S_v, S_v);
|
|
|
report_tensor_size("outer_product", outer_product);
|
|
report_tensor_size("outer_product", outer_product);
|
|
|
|
|
|
|
|
- // Now we need to reshape outer_product to match state_flat dimensions
|
|
|
|
|
- // outer_product: [S_v, S_v] -> reshape to [S_v * S_v, H_v * batch_size]
|
|
|
|
|
struct ggml_tensor * outer_product_reshaped;
|
|
struct ggml_tensor * outer_product_reshaped;
|
|
|
if (outer_product->ne[0] == S_v && outer_product->ne[1] == S_v) {
|
|
if (outer_product->ne[0] == S_v && outer_product->ne[1] == S_v) {
|
|
|
- // Perfect match for a single head/sequence
|
|
|
|
|
outer_product_reshaped = ggml_reshape_2d(ctx, outer_product, S_v * S_v, 1);
|
|
outer_product_reshaped = ggml_reshape_2d(ctx, outer_product, S_v * S_v, 1);
|
|
|
} else {
|
|
} else {
|
|
|
- // Handle whatever dimensions we got
|
|
|
|
|
outer_product_reshaped = ggml_reshape_2d(ctx, outer_product,
|
|
outer_product_reshaped = ggml_reshape_2d(ctx, outer_product,
|
|
|
outer_product->ne[0] * outer_product->ne[1], 1);
|
|
outer_product->ne[0] * outer_product->ne[1], 1);
|
|
|
}
|
|
}
|
|
|
report_tensor_size("outer_product_reshaped", outer_product_reshaped);
|
|
report_tensor_size("outer_product_reshaped", outer_product_reshaped);
|
|
|
|
|
|
|
|
- // Repeat outer_product_reshaped to match the number of heads and batches
|
|
|
|
|
struct ggml_tensor * outer_product_repeated = ggml_repeat(ctx, outer_product_reshaped, state_flat);
|
|
struct ggml_tensor * outer_product_repeated = ggml_repeat(ctx, outer_product_reshaped, state_flat);
|
|
|
report_tensor_size("outer_product_repeated", outer_product_repeated);
|
|
report_tensor_size("outer_product_repeated", outer_product_repeated);
|
|
|
|
|
|
|
|
- // Update state
|
|
|
|
|
state_flat = ggml_add(ctx, state_flat, outer_product_repeated);
|
|
state_flat = ggml_add(ctx, state_flat, outer_product_repeated);
|
|
|
report_tensor_size("state_flat_updated", state_flat);
|
|
report_tensor_size("state_flat_updated", state_flat);
|
|
|
|
|
|
|
|
- // Compute output = current_state @ q_t^T for all heads and batches
|
|
|
|
|
- // Simplified approach: follow Python implementation more closely
|
|
|
|
|
- // In Python: output = (last_recurrent_state * q_t.unsqueeze(-1)).sum(dim=-2)
|
|
|
|
|
- // This means: for each batch and head, multiply state by q_t and sum over the last dimension
|
|
|
|
|
-
|
|
|
|
|
- // First, let's work with the original q_t (already processed to match H_v)
|
|
|
|
|
struct ggml_tensor * q_t_final = q_t;
|
|
struct ggml_tensor * q_t_final = q_t;
|
|
|
report_tensor_size("q_t_final", q_t_final);
|
|
report_tensor_size("q_t_final", q_t_final);
|
|
|
|
|
|
|
|
- // Make q_t_final contiguous for matrix operations
|
|
|
|
|
q_t_final = ggml_cont(ctx, q_t_final);
|
|
q_t_final = ggml_cont(ctx, q_t_final);
|
|
|
report_tensor_size("q_t_final_cont", q_t_final);
|
|
report_tensor_size("q_t_final_cont", q_t_final);
|
|
|
|
|
|
|
|
- // For the output computation, we want: (state * q_t.unsqueeze(-1)).sum(dim=-2)
|
|
|
|
|
- // This is equivalent to: state @ q_t^T where q_t is reshaped appropriately
|
|
|
|
|
-
|
|
|
|
|
- // Simple approach: reshape q_t to [S_k, H_v * batch_size] and state to [S_v * S_v, H_v * batch_size]
|
|
|
|
|
- // Then compute: state^T @ q_t
|
|
|
|
|
- // But we need to handle the GGML requirements
|
|
|
|
|
-
|
|
|
|
|
- // Make state_flat contiguous
|
|
|
|
|
struct ggml_tensor * state_flat_cont = ggml_cont(ctx, state_flat);
|
|
struct ggml_tensor * state_flat_cont = ggml_cont(ctx, state_flat);
|
|
|
report_tensor_size("state_flat_cont", state_flat_cont);
|
|
report_tensor_size("state_flat_cont", state_flat_cont);
|
|
|
|
|
|
|
|
- // Reshape q_t to [S_k, H_v * batch_size] for matrix multiplication
|
|
|
|
|
struct ggml_tensor * q_t_matrix = ggml_reshape_2d(ctx, q_t_final, S_k, H_v * batch_size);
|
|
struct ggml_tensor * q_t_matrix = ggml_reshape_2d(ctx, q_t_final, S_k, H_v * batch_size);
|
|
|
report_tensor_size("q_t_matrix", q_t_matrix);
|
|
report_tensor_size("q_t_matrix", q_t_matrix);
|
|
|
|
|
|
|
|
- // Now we want to compute: state_flat_cont^T @ q_t_matrix
|
|
|
|
|
- // state_flat_cont: [S_v * S_v, H_v * batch_size] = [16384, 16]
|
|
|
|
|
- // q_t_matrix: [S_k, H_v * batch_size] = [128, 16]
|
|
|
|
|
-
|
|
|
|
|
- // For GGML, we need: q_t_matrix^T @ state_flat_cont^T
|
|
|
|
|
- // But GGML doesn't allow transposed first argument, so we use the property: A @ B = (B^T @ A^T)^T
|
|
|
|
|
-
|
|
|
|
|
- // Transpose q_t_matrix to get [H_v * batch_size, S_k] = [16, 128]
|
|
|
|
|
struct ggml_tensor * q_t_matrix_transposed = ggml_transpose(ctx, q_t_matrix);
|
|
struct ggml_tensor * q_t_matrix_transposed = ggml_transpose(ctx, q_t_matrix);
|
|
|
report_tensor_size("q_t_matrix_transposed", q_t_matrix_transposed);
|
|
report_tensor_size("q_t_matrix_transposed", q_t_matrix_transposed);
|
|
|
|
|
|
|
|
- // Transpose state_flat_cont to get [H_v * batch_size, S_v * S_v] = [16, 16384]
|
|
|
|
|
struct ggml_tensor * state_flat_transposed = ggml_transpose(ctx, state_flat_cont);
|
|
struct ggml_tensor * state_flat_transposed = ggml_transpose(ctx, state_flat_cont);
|
|
|
report_tensor_size("state_flat_transposed", state_flat_transposed);
|
|
report_tensor_size("state_flat_transposed", state_flat_transposed);
|
|
|
|
|
|
|
|
- // Now we can do: q_t_matrix_transposed @ state_flat_transposed
|
|
|
|
|
- // [16, 128] @ [16, 16384] - this won't work because first dimensions don't match
|
|
|
|
|
-
|
|
|
|
|
- // Instead, let's do: state_flat_transposed^T @ q_t_matrix_transposed^T
|
|
|
|
|
- // But we need to transpose both again
|
|
|
|
|
struct ggml_tensor * q_t_matrix_final = ggml_transpose(ctx, q_t_matrix_transposed);
|
|
struct ggml_tensor * q_t_matrix_final = ggml_transpose(ctx, q_t_matrix_transposed);
|
|
|
report_tensor_size("q_t_matrix_final", q_t_matrix_final);
|
|
report_tensor_size("q_t_matrix_final", q_t_matrix_final);
|
|
|
|
|
|
|
|
struct ggml_tensor * state_flat_final = ggml_transpose(ctx, state_flat_transposed);
|
|
struct ggml_tensor * state_flat_final = ggml_transpose(ctx, state_flat_transposed);
|
|
|
report_tensor_size("state_flat_final", state_flat_final);
|
|
report_tensor_size("state_flat_final", state_flat_final);
|
|
|
|
|
|
|
|
- // Now we can do: q_t_matrix_final @ state_flat_final
|
|
|
|
|
- // [128, 16] @ [16384, 16] - this won't work either
|
|
|
|
|
-
|
|
|
|
|
- // Let me try a different approach: use element-wise multiplication and sum
|
|
|
|
|
- // We want: (state * q_t.unsqueeze(-1)).sum(dim=-2)
|
|
|
|
|
-
|
|
|
|
|
- // First, reshape q_t to broadcast with state
|
|
|
|
|
struct ggml_tensor * q_t_broadcast = ggml_repeat(ctx, q_t_final, state_flat_cont);
|
|
struct ggml_tensor * q_t_broadcast = ggml_repeat(ctx, q_t_final, state_flat_cont);
|
|
|
report_tensor_size("q_t_broadcast", q_t_broadcast);
|
|
report_tensor_size("q_t_broadcast", q_t_broadcast);
|
|
|
|
|
|
|
|
- // Element-wise multiplication
|
|
|
|
|
struct ggml_tensor * state_q_product = ggml_mul(ctx, state_flat_cont, q_t_broadcast);
|
|
struct ggml_tensor * state_q_product = ggml_mul(ctx, state_flat_cont, q_t_broadcast);
|
|
|
report_tensor_size("state_q_product", state_q_product);
|
|
report_tensor_size("state_q_product", state_q_product);
|
|
|
|
|
|
|
|
- // Let's reshape to separate the dimensions we want to sum over
|
|
|
|
|
-
|
|
|
|
|
- // Reshape state_q_product to [S_v * S_v, H_v, batch_size]
|
|
|
|
|
struct ggml_tensor * state_q_3d = ggml_reshape_3d(ctx, state_q_product, S_v * S_v, H_v, batch_size);
|
|
struct ggml_tensor * state_q_3d = ggml_reshape_3d(ctx, state_q_product, S_v * S_v, H_v, batch_size);
|
|
|
report_tensor_size("state_q_3d", state_q_3d);
|
|
report_tensor_size("state_q_3d", state_q_3d);
|
|
|
- // Ensure contiguous layout so byte-strides are consistent for subsequent views/slices.
|
|
|
|
|
state_q_3d = ggml_cont(ctx, state_q_3d);
|
|
state_q_3d = ggml_cont(ctx, state_q_3d);
|
|
|
report_tensor_size("state_q_3d_cont", state_q_3d);
|
|
report_tensor_size("state_q_3d_cont", state_q_3d);
|
|
|
|
|
|
|
|
- // Sum over the H_v dimension (axis 1)
|
|
|
|
|
- // Create a proper ones vector: ggml_new_tensor_1d already creates a zero-filled tensor,
|
|
|
|
|
- // so ggml_exp on it will produce ones (exp(0) = 1).
|
|
|
|
|
struct ggml_tensor * ones_vector = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, H_v);
|
|
struct ggml_tensor * ones_vector = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, H_v);
|
|
|
ones_vector = ggml_exp(ctx, ones_vector); // exp(0) = 1
|
|
ones_vector = ggml_exp(ctx, ones_vector); // exp(0) = 1
|
|
|
report_tensor_size("ones_vector", ones_vector);
|
|
report_tensor_size("ones_vector", ones_vector);
|
|
|
|
|
|
|
|
- // Reshape to [H_v, 1] for matrix multiplication
|
|
|
|
|
struct ggml_tensor * ones_col = ggml_reshape_2d(ctx, ones_vector, H_v, 1);
|
|
struct ggml_tensor * ones_col = ggml_reshape_2d(ctx, ones_vector, H_v, 1);
|
|
|
report_tensor_size("ones_col", ones_col);
|
|
report_tensor_size("ones_col", ones_col);
|
|
|
|
|
|
|
|
- // Prepare per-batch results
|
|
|
|
|
struct ggml_tensor * output_parts[batch_size];
|
|
struct ggml_tensor * output_parts[batch_size];
|
|
|
for (int64_t b = 0; b < batch_size; b++) {
|
|
for (int64_t b = 0; b < batch_size; b++) {
|
|
|
- // Extract slice for this batch: [S_v * S_v, H_v]
|
|
|
|
|
- // Use the contiguous state_q_3d so nb and offsets are reliable.
|
|
|
|
|
struct ggml_tensor * batch_slice = ggml_view_3d(ctx, state_q_3d, S_v * S_v, H_v, 1,
|
|
struct ggml_tensor * batch_slice = ggml_view_3d(ctx, state_q_3d, S_v * S_v, H_v, 1,
|
|
|
state_q_3d->nb[1], state_q_3d->nb[2], b * state_q_3d->nb[2]);
|
|
state_q_3d->nb[1], state_q_3d->nb[2], b * state_q_3d->nb[2]);
|
|
|
batch_slice = ggml_cont(ctx, batch_slice);
|
|
batch_slice = ggml_cont(ctx, batch_slice);
|
|
|
report_tensor_size("batch_slice", batch_slice);
|
|
report_tensor_size("batch_slice", batch_slice);
|
|
|
|
|
|
|
|
- // Multiply by ones and sum across H_v:
|
|
|
|
|
- // ones_col: [H_v, 1], batch_slice^T: [H_v, S_v * S_v] -> ones_col @ batch_slice^T = [1, S_v * S_v]
|
|
|
|
|
struct ggml_tensor * batch_slice_t = ggml_transpose(ctx, batch_slice);
|
|
struct ggml_tensor * batch_slice_t = ggml_transpose(ctx, batch_slice);
|
|
|
report_tensor_size("batch_slice_t", batch_slice_t);
|
|
report_tensor_size("batch_slice_t", batch_slice_t);
|
|
|
struct ggml_tensor * batch_sum = ggml_mul_mat(ctx, ones_col, batch_slice_t);
|
|
struct ggml_tensor * batch_sum = ggml_mul_mat(ctx, ones_col, batch_slice_t);
|
|
|
report_tensor_size("batch_sum", batch_sum);
|
|
report_tensor_size("batch_sum", batch_sum);
|
|
|
|
|
|
|
|
- // Reshape [1, S_v*S_v] -> [S_v, S_v]
|
|
|
|
|
struct ggml_tensor * batch_result = ggml_reshape_2d(ctx, batch_sum, S_v, S_v);
|
|
struct ggml_tensor * batch_result = ggml_reshape_2d(ctx, batch_sum, S_v, S_v);
|
|
|
report_tensor_size("batch_result", batch_result);
|
|
report_tensor_size("batch_result", batch_result);
|
|
|
output_parts[b] = batch_result;
|
|
output_parts[b] = batch_result;
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
- // Concatenate results from all batches into [S_v * S_v, batch_size]
|
|
|
|
|
struct ggml_tensor * output_concat = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, S_v * S_v, batch_size);
|
|
struct ggml_tensor * output_concat = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, S_v * S_v, batch_size);
|
|
|
for (int64_t b = 0; b < batch_size; b++) {
|
|
for (int64_t b = 0; b < batch_size; b++) {
|
|
|
struct ggml_tensor * batch_output = ggml_view_2d(ctx, output_concat, S_v * S_v, 1,
|
|
struct ggml_tensor * batch_output = ggml_view_2d(ctx, output_concat, S_v * S_v, 1,
|
|
@@ -695,12 +480,10 @@ struct ggml_tensor * ggml_delta_net_op(
|
|
|
batch_output = ggml_cpy(ctx, output_parts[b], batch_output);
|
|
batch_output = ggml_cpy(ctx, output_parts[b], batch_output);
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
- // Reshape concatenated result to [S_v, S_v] for this token (batch_size typically 1)
|
|
|
|
|
struct ggml_tensor * output_t_reshaped = ggml_reshape_2d(ctx, output_concat, S_v, S_v);
|
|
struct ggml_tensor * output_t_reshaped = ggml_reshape_2d(ctx, output_concat, S_v, S_v);
|
|
|
struct ggml_tensor * output_t = ggml_cont(ctx, output_t_reshaped);
|
|
struct ggml_tensor * output_t = ggml_cont(ctx, output_t_reshaped);
|
|
|
report_tensor_size("output_t", output_t);
|
|
report_tensor_size("output_t", output_t);
|
|
|
|
|
|
|
|
- // Store output for this token
|
|
|
|
|
struct ggml_tensor * output_slice = ggml_view_3d(ctx, output, S_v, S_v, batch_size,
|
|
struct ggml_tensor * output_slice = ggml_view_3d(ctx, output, S_v, S_v, batch_size,
|
|
|
output->nb[1], output->nb[2], t * output->nb[2]);
|
|
output->nb[1], output->nb[2], t * output->nb[2]);
|
|
|
report_tensor_size("output_slice", output_slice);
|
|
report_tensor_size("output_slice", output_slice);
|
|
@@ -712,4 +495,3 @@ struct ggml_tensor * ggml_delta_net_op(
|
|
|
report_tensor_size("result_final", result);
|
|
report_tensor_size("result_final", result);
|
|
|
return result;
|
|
return result;
|
|
|
}
|
|
}
|
|
|
-// ggml_rwkv_wkv7
|
|
|