
Corrected Implementation of Qwen3Next Support

Çetin 3 weeks ago
parent
commit
45ada635f0

+ 13 - 0
ggml/include/ggml.h

@@ -549,6 +549,7 @@ extern "C" {
         GGML_OP_ADD_REL_POS,
         GGML_OP_RWKV_WKV6,
         GGML_OP_GATED_LINEAR_ATTN,
+        GGML_OP_GATED_DELTA_RULE,
         GGML_OP_RWKV_WKV7,
         GGML_OP_SOLVE_TRI,
 
@@ -2429,6 +2430,18 @@ extern "C" {
             struct ggml_tensor  * state,
             float scale);
 
+    // Gated Delta Rule (the recurrence used by Gated DeltaNet): returns concatenated output + updated state
+    GGML_API struct ggml_tensor * ggml_gated_delta_rule(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * q,
+            struct ggml_tensor  * k,
+            struct ggml_tensor  * v,
+            struct ggml_tensor  * g,
+            struct ggml_tensor  * beta,
+            struct ggml_tensor  * state,
+            float                 scale,
+            float                 eps);
+
     GGML_API struct ggml_tensor * ggml_rwkv_wkv7(
             struct ggml_context * ctx,
             struct ggml_tensor  * r,
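
For orientation, a minimal caller-side sketch of the new API (not part of the commit; it assumes an existing ggml_context * ctx, and the sizes are illustrative — shapes follow the comments in ops.cpp below: q/k/v are {D, H, T, B}, g/beta are {H, T, B}, state is {D, D, H, B}):

    // hypothetical sizes for illustration only
    const int64_t D = 64, H = 8, T = 32, B = 1;

    struct ggml_tensor * q     = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, D, H, T, B);
    struct ggml_tensor * k     = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, D, H, T, B);
    struct ggml_tensor * v     = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, D, H, T, B);
    struct ggml_tensor * g     = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, H, T, B);    // log-decay (pre-exp)
    struct ggml_tensor * beta  = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, H, T, B);    // mixing coeff (pre-sigmoid)
    struct ggml_tensor * state = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, D, D, H, B); // recurrent state

    struct ggml_tensor * out = ggml_gated_delta_rule(ctx, q, k, v, g, beta, state,
                                                     1.0f/8.0f /* 1/sqrt(D) */, 1e-6f);

    // the result is one flat F32 tensor: attention output first, updated state after it
    struct ggml_tensor * o_flat = ggml_view_1d(ctx, out, ggml_nelements(v), 0);
    struct ggml_tensor * s_flat = ggml_view_1d(ctx, out, ggml_nelements(state),
                                               ggml_nelements(v) * ggml_element_size(out));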

+ 4 - 0
ggml/src/ggml-cpu/ggml-cpu.c

@@ -2010,6 +2010,10 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
             {
                 ggml_compute_forward_gla(params, tensor);
             } break;
+        case GGML_OP_GATED_DELTA_RULE:
+            {
+                ggml_compute_forward_gated_delta_rule(params, tensor);
+            } break;
         case GGML_OP_RWKV_WKV7:
             {
                 ggml_compute_forward_rwkv_wkv7(params, tensor);

+ 171 - 0
ggml/src/ggml-cpu/ops.cpp

@@ -4359,6 +4359,7 @@ static void ggml_compute_forward_scale_f32(
     }
 }
 
+
 void ggml_compute_forward_scale(
         const ggml_compute_params * params,
         ggml_tensor * dst) {
@@ -9800,6 +9801,176 @@ void ggml_compute_forward_gla(
     }
 }
 
+// ggml_compute_forward_gated_delta_rule
+
+static inline float ggml_compute_sigmoid_f32(float x) {
+    // numerically stable sigmoid
+    if (x >= 0.0f) {
+        const float z = expf(-x);
+        return 1.0f / (1.0f + z);
+    } else {
+        const float z = expf(x);
+        return z / (1.0f + z);
+    }
+}
+
+static void ggml_compute_forward_gated_delta_rule_f32(
+        const ggml_compute_params * params,
+        ggml_tensor * dst) {
+    const ggml_tensor * q    = dst->src[0]; // {D, H, T, B}
+    const ggml_tensor * k    = dst->src[1]; // {D, H, T, B}
+    const ggml_tensor * v    = dst->src[2]; // {D, H, T, B}
+    const ggml_tensor * g    = dst->src[3]; // {H, T, B}
+    const ggml_tensor * beta = dst->src[4]; // {H, T, B}
+    const ggml_tensor * s    = dst->src[5]; // {D, D, H, B}
+
+    const int64_t D = q->ne[0];
+    const int64_t H = q->ne[1];
+    const int64_t T = q->ne[2];
+    const int64_t B = q->ne[3];
+
+    GGML_ASSERT(k->ne[0] == D && k->ne[1] == H && k->ne[2] == T && k->ne[3] == B);
+    GGML_ASSERT(v->ne[0] == D && v->ne[1] == H && v->ne[2] == T && v->ne[3] == B);
+    GGML_ASSERT(g->ne[0] == H && g->ne[1] == T && g->ne[2] == B);
+    GGML_ASSERT(beta->ne[0] == H && beta->ne[1] == T && beta->ne[2] == B);
+    GGML_ASSERT(s->ne[0] == D && s->ne[1] == D && s->ne[2] == H && s->ne[3] == B);
+    GGML_ASSERT(s->type == GGML_TYPE_F32 || s->type == GGML_TYPE_F16);
+
+    const float q_scale = ggml_get_op_params_f32(dst, 0);
+    const float eps     = ggml_get_op_params_f32(dst, 1);
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    const int64_t n_heads_total = B * H;
+    const int64_t dh = (n_heads_total + nth - 1) / nth;
+    const int64_t ih0 = dh * ith;
+    const int64_t ih1 = MIN(ih0 + dh, n_heads_total);
+
+    float * out = (float *) dst->data;
+    float * out_state = out + ggml_nelements(v);
+
+    const char * q_data    = (const char *) q->data;
+    const char * k_data    = (const char *) k->data;
+    const char * v_data    = (const char *) v->data;
+    const char * g_data    = (const char *) g->data;
+    const char * beta_data = (const char *) beta->data;
+    const void * s_data     = s->data;
+    const bool s_is_f16     = s->type == GGML_TYPE_F16;
+
+    const size_t q_nb1 = q->nb[1];
+    const size_t q_nb2 = q->nb[2];
+    const size_t q_nb3 = q->nb[3];
+    const size_t k_nb1 = k->nb[1];
+    const size_t k_nb2 = k->nb[2];
+    const size_t k_nb3 = k->nb[3];
+    const size_t v_nb1 = v->nb[1];
+    const size_t v_nb2 = v->nb[2];
+    const size_t v_nb3 = v->nb[3];
+    const size_t g_nb0 = g->nb[0];
+    const size_t g_nb1 = g->nb[1];
+    const size_t g_nb2 = g->nb[2];
+    const size_t beta_nb0 = beta->nb[0];
+    const size_t beta_nb1 = beta->nb[1];
+    const size_t beta_nb2 = beta->nb[2];
+
+    std::vector<float> qn(D);
+    std::vector<float> kn(D);
+    std::vector<float> kv_mem(D);
+    std::vector<float> v_new(D);
+
+    for (int64_t bh = ih0; bh < ih1; ++bh) {
+        const int64_t b = bh / H;
+        const int64_t h = bh % H;
+
+        float * state = out_state + (bh * D * D);
+
+        // initialize state from input
+        if (s_is_f16) {
+            const ggml_fp16_t * s_src = (const ggml_fp16_t *) s_data + bh * D * D;
+            for (int64_t row = 0; row < D; ++row) {
+                for (int64_t col = 0; col < D; ++col) {
+                    state[row * D + col] = GGML_FP16_TO_FP32(s_src[row * D + col]);
+                }
+            }
+        } else {
+            const float * s_src = (const float *) s_data + bh * D * D;
+            memcpy(state, s_src, D * D * sizeof(float));
+        }
+
+        for (int64_t t = 0; t < T; ++t) {
+            const int64_t base_qkv = D * (h + H * (t + T * b));
+
+            const float * q_t = (const float *) (q_data + h*q_nb1 + t*q_nb2 + b*q_nb3);
+            const float * k_t = (const float *) (k_data + h*k_nb1 + t*k_nb2 + b*k_nb3);
+            const float * v_t = (const float *) (v_data + h*v_nb1 + t*v_nb2 + b*v_nb3);
+
+            // l2-norm(q), l2-norm(k)
+            float q_ss = 0.0f;
+            float k_ss = 0.0f;
+            for (int64_t d = 0; d < D; ++d) {
+                q_ss += q_t[d] * q_t[d];
+                k_ss += k_t[d] * k_t[d];
+            }
+            const float q_inv = 1.0f / sqrtf(q_ss + eps);
+            const float k_inv = 1.0f / sqrtf(k_ss + eps);
+            for (int64_t d = 0; d < D; ++d) {
+                qn[d] = q_t[d] * q_inv * q_scale;
+                kn[d] = k_t[d] * k_inv;
+            }
+
+            const float g_t    = *(const float *) (g_data + h*g_nb0 + t*g_nb1 + b*g_nb2);
+            const float beta_t = *(const float *) (beta_data + h*beta_nb0 + t*beta_nb1 + b*beta_nb2);
+            const float gexp   = expf(g_t);
+            const float b_sig  = ggml_compute_sigmoid_f32(beta_t);
+
+            // decay state + compute kv_mem = (k^T @ state) (per output dim)
+            std::fill(kv_mem.begin(), kv_mem.end(), 0.0f);
+            for (int64_t row = 0; row < D; ++row) {
+                const float k_row = kn[row];
+                float * state_row = state + row * D;
+                for (int64_t col = 0; col < D; ++col) {
+                    state_row[col] *= gexp;
+                    kv_mem[col] += state_row[col] * k_row;
+                }
+            }
+
+            // v_new = beta * (v - kv_mem)
+            for (int64_t col = 0; col < D; ++col) {
+                v_new[col] = b_sig * (v_t[col] - kv_mem[col]);
+            }
+
+            // state += k ⊗ v_new
+            for (int64_t row = 0; row < D; ++row) {
+                const float k_row = kn[row];
+                float * state_row = state + row * D;
+                for (int64_t col = 0; col < D; ++col) {
+                    state_row[col] += k_row * v_new[col];
+                }
+            }
+
+            // output = q^T @ state
+            float * out_t = out + base_qkv;
+            for (int64_t col = 0; col < D; ++col) {
+                float sum = 0.0f;
+                for (int64_t row = 0; row < D; ++row) {
+                    sum += qn[row] * state[row * D + col];
+                }
+                out_t[col] = sum;
+            }
+        }
+    }
+}
+
+void ggml_compute_forward_gated_delta_rule(
+        const ggml_compute_params * params,
+        ggml_tensor * dst) {
+    GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32);
+    ggml_compute_forward_gated_delta_rule_f32(params, dst);
+}
+
 static void ggml_compute_forward_solve_tri_f32(const struct ggml_compute_params * params, struct ggml_tensor * dst) {
     const struct ggml_tensor * src0 = dst->src[0];  // A (lower triangular)
     const struct ggml_tensor * src1 = dst->src[1];  // B (RHS)
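
As a reading aid (not part of the commit), the per-head recurrence that the loop above and the CUDA kernel below both implement can be written as a single-timestep scalar reference; S is the row-major D x D state and the function name is illustrative:

    #include <math.h>

    // one (head, timestep) update of the gated delta rule, plain C sketch
    static void gdr_step_ref(int D, float * S, const float * q, const float * k, const float * v,
                             float g, float beta, float scale, float eps, float * o) {
        float qs = 0.0f, ks = 0.0f;
        for (int d = 0; d < D; ++d) { qs += q[d]*q[d]; ks += k[d]*k[d]; }
        const float qi = scale / sqrtf(qs + eps);     // q_hat = scale * q / ||q||
        const float ki = 1.0f  / sqrtf(ks + eps);     // k_hat = k / ||k||
        const float ge = expf(g);                     // gate decay
        const float bs = 1.0f / (1.0f + expf(-beta)); // sigmoid (the kernels use the stable split form)

        for (int c = 0; c < D; ++c) {
            float kv = 0.0f;
            for (int r = 0; r < D; ++r) {
                S[r*D + c] *= ge;                     // S <- exp(g) * S
                kv += S[r*D + c] * k[r]*ki;           // kv = k_hat^T S
            }
            const float vn = bs * (v[c] - kv);        // v_new = sigmoid(beta) * (v - kv)
            float acc = 0.0f;
            for (int r = 0; r < D; ++r) {
                S[r*D + c] += k[r]*ki * vn;           // S <- S + k_hat v_new^T
                acc += S[r*D + c] * q[r]*qi;          // o = q_hat^T S
            }
            o[c] = acc;
        }
    }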

+ 1 - 0
ggml/src/ggml-cpu/ops.h

@@ -102,6 +102,7 @@ void ggml_compute_forward_rwkv_wkv6(const struct ggml_compute_params * params, s
 void ggml_compute_forward_rwkv_wkv7(const struct ggml_compute_params * params, struct ggml_tensor * dst);
 void ggml_compute_forward_solve_tri(const struct ggml_compute_params * params, struct ggml_tensor * dst);
 void ggml_compute_forward_gla(const struct ggml_compute_params * params, struct ggml_tensor * dst);
+void ggml_compute_forward_gated_delta_rule(const struct ggml_compute_params * params, struct ggml_tensor * dst);
 void ggml_compute_forward_map_custom1(const struct ggml_compute_params * params, struct ggml_tensor * dst);
 void ggml_compute_forward_map_custom2(const struct ggml_compute_params * params, struct ggml_tensor * dst);
 void ggml_compute_forward_map_custom3(const struct ggml_compute_params * params, struct ggml_tensor * dst);

+ 1 - 1
ggml/src/ggml-cuda/CMakeLists.txt

@@ -43,7 +43,7 @@ if (CUDAToolkit_FOUND)
     file(GLOB   GGML_HEADERS_CUDA "*.cuh")
     list(APPEND GGML_HEADERS_CUDA "../../include/ggml-cuda.h")
 
-    file(GLOB   GGML_SOURCES_CUDA "*.cu")
+    file(GLOB   GGML_SOURCES_CUDA CONFIGURE_DEPENDS "*.cu")
     file(GLOB   SRCS "template-instances/fattn-tile*.cu")
     list(APPEND GGML_SOURCES_CUDA ${SRCS})
     file(GLOB   SRCS "template-instances/fattn-mma*.cu")

+ 242 - 0
ggml/src/ggml-cuda/gated-delta-rule.cu

@@ -0,0 +1,242 @@
+#include "common.cuh"
+#include "gated-delta-rule.cuh"
+
+static __device__ __forceinline__ float sigmoid_f32(float x) {
+    if (x >= 0.0f) {
+        const float z = expf(-x);
+        return 1.0f / (1.0f + z);
+    } else {
+        const float z = expf(x);
+        return z / (1.0f + z);
+    }
+}
+
+template <typename T>
+static __device__ __forceinline__ float load_f32(const T * __restrict__ p) {
+    return (float) *p;
+}
+
+template <>
+__device__ __forceinline__ float load_f32<half>(const half * __restrict__ p) {
+    return __half2float(*p);
+}
+
+template <int K, int BV, typename T, typename S>
+static __global__ void gated_delta_rule_fwd(
+        const T     * __restrict__ q,       // [K, H, T, B]
+        const T     * __restrict__ k,       // [K, H, T, B]
+        const T     * __restrict__ v,       // [K, H, T, B]
+        const T     * __restrict__ g,       // [H, T, B]
+        const T     * __restrict__ beta,    // [H, T, B] (pre-sigmoid)
+        const S     * __restrict__ s,       // [K, K, H, B] (row-major: s[row][col])
+        float       * __restrict__ o,       // [K, H, T, B]
+        float       * __restrict__ st,      // [K, K, H, B]
+        const int                 H,
+        const int                 T_len,
+        const float               q_scale,
+        const float               eps,
+        const int64_t             q_nb1,
+        const int64_t             q_nb2,
+        const int64_t             q_nb3,
+        const int64_t             k_nb1,
+        const int64_t             k_nb2,
+        const int64_t             k_nb3,
+        const int64_t             v_nb1,
+        const int64_t             v_nb2,
+        const int64_t             v_nb3,
+        const int64_t             g_nb1,
+        const int64_t             g_nb2,
+        const int64_t             beta_nb1,
+        const int64_t             beta_nb2) {
+    static_assert(K % WARP_SIZE == 0, "K must be divisible by warp size");
+    static_assert(BV <= WARP_SIZE, "BV must be <= warp size");
+
+    const int lane = threadIdx.x;
+    const int v_tile = blockIdx.x;
+    const int bh = blockIdx.y;
+    const int b = bh / H;
+    const int h = bh - b * H;
+    const int v0 = v_tile * BV;
+
+    constexpr int rows_per_thread = K / WARP_SIZE;
+    float state[rows_per_thread][BV];
+
+    const int64_t s_base = (int64_t) bh * K * K;
+
+    // Load initial state
+    #pragma unroll
+    for (int rr = 0; rr < rows_per_thread; ++rr) {
+        const int row = lane + rr * WARP_SIZE;
+        #pragma unroll
+        for (int cc = 0; cc < BV; ++cc) {
+            const int col = v0 + cc;
+            state[rr][cc] = col < K ? load_f32(s + s_base + (int64_t) row * K + col) : 0.0f;
+        }
+    }
+
+    for (int t = 0; t < T_len; ++t) {
+        const int64_t q_base    = (int64_t) h * q_nb1 + (int64_t) t * q_nb2 + (int64_t) b * q_nb3;
+        const int64_t k_base    = (int64_t) h * k_nb1 + (int64_t) t * k_nb2 + (int64_t) b * k_nb3;
+        const int64_t v_base    = (int64_t) h * v_nb1 + (int64_t) t * v_nb2 + (int64_t) b * v_nb3;
+        const int64_t g_base    = (int64_t) h + (int64_t) t * g_nb1 + (int64_t) b * g_nb2;
+        const int64_t beta_base = (int64_t) h + (int64_t) t * beta_nb1 + (int64_t) b * beta_nb2;
+        const int64_t out_base  = (int64_t) K * (h + H * (t + T_len * b));
+
+        float q_raw[rows_per_thread];
+        float k_raw[rows_per_thread];
+        float q_ss = 0.0f;
+        float k_ss = 0.0f;
+
+        #pragma unroll
+        for (int rr = 0; rr < rows_per_thread; ++rr) {
+            const int idx = lane + rr * WARP_SIZE;
+            const float qv = load_f32(q + q_base + idx);
+            const float kv = load_f32(k + k_base + idx);
+            q_raw[rr] = qv;
+            k_raw[rr] = kv;
+            q_ss += qv * qv;
+            k_ss += kv * kv;
+        }
+
+        q_ss = warp_reduce_sum(q_ss);
+        k_ss = warp_reduce_sum(k_ss);
+
+        const float q_inv = rsqrtf(q_ss + eps);
+        const float k_inv = rsqrtf(k_ss + eps);
+
+        float qn[rows_per_thread];
+        float kn[rows_per_thread];
+
+        #pragma unroll
+        for (int rr = 0; rr < rows_per_thread; ++rr) {
+            qn[rr] = q_raw[rr] * q_inv * q_scale;
+            kn[rr] = k_raw[rr] * k_inv;
+        }
+
+        float gexp = 0.0f;
+        float bsig = 0.0f;
+        if (lane == 0) {
+            gexp = expf(load_f32(g + g_base));
+            bsig = sigmoid_f32(load_f32(beta + beta_base));
+        }
+        gexp = __shfl_sync(0xffffffff, gexp, 0);
+        bsig = __shfl_sync(0xffffffff, bsig, 0);
+
+        #pragma unroll
+        for (int cc = 0; cc < BV; ++cc) {
+            const int col = v0 + cc;
+            if (col >= K) continue;
+
+            float partial = 0.0f;
+            #pragma unroll
+            for (int rr = 0; rr < rows_per_thread; ++rr) {
+                state[rr][cc] *= gexp;
+                partial += state[rr][cc] * kn[rr];
+            }
+            const float dot_k = warp_reduce_sum(partial);
+
+            float v_in = 0.0f;
+            if (lane == cc) {
+                v_in = load_f32(v + v_base + col);
+            }
+            v_in = __shfl_sync(0xffffffff, v_in, cc);
+
+            const float v_new = bsig * (v_in - dot_k);
+
+            float partial_o = 0.0f;
+            #pragma unroll
+            for (int rr = 0; rr < rows_per_thread; ++rr) {
+                state[rr][cc] += kn[rr] * v_new;
+                partial_o += state[rr][cc] * qn[rr];
+            }
+            const float out = warp_reduce_sum(partial_o);
+            if (lane == cc) {
+                o[out_base + col] = out;
+            }
+        }
+    }
+
+    // Store final state
+    const int64_t st_base = (int64_t) bh * K * K;
+    #pragma unroll
+    for (int rr = 0; rr < rows_per_thread; ++rr) {
+        const int row = lane + rr * WARP_SIZE;
+        #pragma unroll
+        for (int cc = 0; cc < BV; ++cc) {
+            const int col = v0 + cc;
+            if (col < K) {
+                st[st_base + (int64_t) row * K + col] = state[rr][cc];
+            }
+        }
+    }
+}
+
+void ggml_cuda_op_gated_delta_rule(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+    const ggml_tensor * q    = dst->src[0];
+    const ggml_tensor * k    = dst->src[1];
+    const ggml_tensor * v    = dst->src[2];
+    const ggml_tensor * g    = dst->src[3];
+    const ggml_tensor * beta = dst->src[4];
+    const ggml_tensor * s    = dst->src[5];
+
+    const int K = (int) q->ne[0];
+    const int H = (int) q->ne[1];
+    const int T = (int) q->ne[2];
+    const int B = (int) q->ne[3];
+
+    const float q_scale = ggml_get_op_params_f32(dst, 0);
+    const float eps     = ggml_get_op_params_f32(dst, 1);
+
+    const size_t tsize = ggml_type_size(q->type);
+    const int64_t out_elems = (int64_t) K * H * T * B;
+
+    float * dst_d = (float *) dst->data;
+    float * o_d   = dst_d;
+    float * st_d  = dst_d + out_elems;
+
+    const int64_t q_nb1 = q->nb[1] / tsize;
+    const int64_t q_nb2 = q->nb[2] / tsize;
+    const int64_t q_nb3 = q->nb[3] / tsize;
+
+    const int64_t k_nb1 = k->nb[1] / tsize;
+    const int64_t k_nb2 = k->nb[2] / tsize;
+    const int64_t k_nb3 = k->nb[3] / tsize;
+
+    const int64_t v_nb1 = v->nb[1] / tsize;
+    const int64_t v_nb2 = v->nb[2] / tsize;
+    const int64_t v_nb3 = v->nb[3] / tsize;
+
+    const int64_t g_nb1 = g->nb[1] / tsize;
+    const int64_t g_nb2 = g->nb[2] / tsize;
+
+    const int64_t beta_nb1 = beta->nb[1] / tsize;
+    const int64_t beta_nb2 = beta->nb[2] / tsize;
+
+    constexpr int BV = 8;
+    const dim3 grid((K + BV - 1) / BV, (unsigned) (B * H), 1);
+    const dim3 block(WARP_SIZE, 1, 1);
+    cudaStream_t stream = ctx.stream();
+
+    // All math is performed in F32; the templates only select the storage type of
+    // q/k/v/g/beta (first type) and of the input state (second type).
+#define GDR_LAUNCH(KDIM, TQ, TS)                                                      \
+    gated_delta_rule_fwd<KDIM, BV, TQ, TS><<<grid, block, 0, stream>>>(               \
+        (const TQ *) q->data, (const TQ *) k->data, (const TQ *) v->data,             \
+        (const TQ *) g->data, (const TQ *) beta->data, (const TS *) s->data,          \
+        o_d, st_d, H, T, q_scale, eps,                                                \
+        q_nb1, q_nb2, q_nb3, k_nb1, k_nb2, k_nb3, v_nb1, v_nb2, v_nb3,                \
+        g_nb1, g_nb2, beta_nb1, beta_nb2)
+
+    if (q->type == GGML_TYPE_F16) {
+        if (s->type == GGML_TYPE_F16) {
+            if      (K ==  64) GDR_LAUNCH( 64, half, half);
+            else if (K == 128) GDR_LAUNCH(128, half, half);
+            else               GGML_ABORT("unsupported head dim");
+        } else {
+            if      (K ==  64) GDR_LAUNCH( 64, half, float);
+            else if (K == 128) GDR_LAUNCH(128, half, float);
+            else               GGML_ABORT("unsupported head dim");
+        }
+    } else {
+        if (s->type == GGML_TYPE_F16) {
+            if      (K ==  64) GDR_LAUNCH( 64, float, half);
+            else if (K == 128) GDR_LAUNCH(128, float, half);
+            else               GGML_ABORT("unsupported head dim");
+        } else {
+            if      (K ==  64) GDR_LAUNCH( 64, float, float);
+            else if (K == 128) GDR_LAUNCH(128, float, float);
+            else               GGML_ABORT("unsupported head dim");
+        }
+    }
+
+#undef GDR_LAUNCH
+}
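
For orientation, the launch above maps work onto the GPU as follows (a sketch of the decomposition, inferred from the kernel rather than stated in the commit):

    // illustrative mapping for gated_delta_rule_fwd:
    //   blockIdx.y           -> one (batch, head) pair               (B*H blocks)
    //   blockIdx.x           -> one tile of BV state columns         (ceil(K/BV) blocks)
    //   threadIdx.x (lane)   -> owns state rows {lane, lane + 32, ...}
    //   registers per thread -> (K / WARP_SIZE) * BV floats of state,
    //                           e.g. K = 128, BV = 8 -> 32 floats
    // Each block is a single warp that keeps its K x BV slice of the state in registers,
    // walks the T timesteps sequentially, and uses warp_reduce_sum() for the
    // k_hat^T S and q_hat^T S dot products.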

+ 3 - 0
ggml/src/ggml-cuda/gated-delta-rule.cuh

@@ -0,0 +1,3 @@
+#pragma once
+#include "common.cuh"
+void ggml_cuda_op_gated_delta_rule(ggml_backend_cuda_context & ctx, ggml_tensor * dst);

+ 45 - 2
ggml/src/ggml-cuda/ggml-cuda.cu

@@ -51,6 +51,7 @@
 #include "ggml-cuda/upscale.cuh"
 #include "ggml-cuda/wkv.cuh"
 #include "ggml-cuda/gla.cuh"
+#include "ggml-cuda/gated-delta-rule.cuh"
 #include "ggml-cuda/set.cuh"
 #include "ggml-cuda/set-rows.cuh"
 #include "ggml-cuda/pad_reflect_1d.cuh"
@@ -2720,6 +2721,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
         case GGML_OP_GATED_LINEAR_ATTN:
             ggml_cuda_op_gated_linear_attn(ctx, dst);
             break;
+        case GGML_OP_GATED_DELTA_RULE:
+            ggml_cuda_op_gated_delta_rule(ctx, dst);
+            break;
         case GGML_OP_RWKV_WKV7:
             ggml_cuda_op_rwkv_wkv7(ctx, dst);
             break;
@@ -3194,8 +3198,9 @@ static bool ggml_cuda_can_fuse(const struct ggml_cgraph * cgraph, int node_idx,
         const ggml_tensor *tanh   = cgraph->nodes[node_idx+1];
         const ggml_tensor *scale2 = cgraph->nodes[node_idx+2];
 
-        GGML_ASSERT(scale->src[0]->type == GGML_TYPE_F32);
-        GGML_ASSERT(scale->type == GGML_TYPE_F32);
+        if (scale->src[0]->type != GGML_TYPE_F32 || scale->type != GGML_TYPE_F32) {
+            return false;
+        }
 
         if (ggml_get_unary_op(tanh) != GGML_UNARY_OP_TANH) {
             return false;
@@ -4611,6 +4616,44 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
         case GGML_OP_GROUP_NORM:
         case GGML_OP_PAD:
             return ggml_is_contiguous(op->src[0]);
+        case GGML_OP_GATED_DELTA_RULE: {
+            const ggml_tensor * q = op->src[0];
+            const ggml_tensor * k = op->src[1];
+            const ggml_tensor * v = op->src[2];
+            const ggml_tensor * g = op->src[3];
+            const ggml_tensor * beta = op->src[4];
+            const ggml_tensor * s = op->src[5];
+            const int64_t D = q->ne[0];
+            const ggml_type qtype = q->type;
+            const bool type_ok = (qtype == GGML_TYPE_F32 || qtype == GGML_TYPE_F16) &&
+                                 k->type == qtype &&
+                                 v->type == qtype &&
+                                 g->type == qtype &&
+                                 beta->type == qtype &&
+                                 (s->type == GGML_TYPE_F32 || s->type == GGML_TYPE_F16);
+            const size_t tsize = ggml_type_size(qtype);
+            const size_t ssize = ggml_type_size(s->type);
+            const bool stride_ok =
+                q->nb[0] == tsize && k->nb[0] == tsize && v->nb[0] == tsize &&
+                g->nb[0] == tsize && beta->nb[0] == tsize &&
+                q->nb[1] % tsize == 0 && q->nb[2] % tsize == 0 && q->nb[3] % tsize == 0 &&
+                k->nb[1] % tsize == 0 && k->nb[2] % tsize == 0 && k->nb[3] % tsize == 0 &&
+                v->nb[1] % tsize == 0 && v->nb[2] % tsize == 0 && v->nb[3] % tsize == 0 &&
+                g->nb[1] % tsize == 0 && g->nb[2] % tsize == 0 &&
+                beta->nb[1] % tsize == 0 && beta->nb[2] % tsize == 0 &&
+                s->nb[0] == ssize;
+            return type_ok &&
+                   stride_ok &&
+                   ggml_is_contiguous(s) &&
+                   ggml_are_same_shape(op->src[0], op->src[1]) &&
+                   ggml_are_same_shape(op->src[0], op->src[2]) &&
+                   ggml_is_3d(op->src[3]) &&
+                   ggml_is_3d(op->src[4]) &&
+                   op->src[3]->ne[0] == q->ne[1] && op->src[3]->ne[1] == q->ne[2] && op->src[3]->ne[2] == q->ne[3] &&
+                   op->src[4]->ne[0] == q->ne[1] && op->src[4]->ne[1] == q->ne[2] && op->src[4]->ne[2] == q->ne[3] &&
+                   s->ne[0] == D && s->ne[1] == D && s->ne[2] == q->ne[1] && s->ne[3] == q->ne[3] &&
+                   (D == 64 || D == 128);
+        }
         case GGML_OP_UPSCALE:
         case GGML_OP_PAD_REFLECT_1D:
         case GGML_OP_ARANGE:

+ 1 - 3
ggml/src/ggml-cuda/scale.cu

@@ -18,8 +18,6 @@ static void scale_f32_cuda(const float * x, float * dst, const float scale, cons
 
 void ggml_cuda_op_scale(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
     const ggml_tensor * src0 = dst->src[0];
-    const float * src0_d = (const float *)src0->data;
-    float * dst_d = (float *)dst->data;
     cudaStream_t stream = ctx.stream();
 
     GGML_ASSERT(src0->type == GGML_TYPE_F32);
@@ -30,5 +28,5 @@ void ggml_cuda_op_scale(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
     memcpy(&scale, (float *) dst->op_params + 0, sizeof(float));
     memcpy(&bias,  (float *) dst->op_params + 1, sizeof(float));
 
-    scale_f32_cuda(src0_d, dst_d, scale, bias, ggml_nelements(src0), stream);
+    scale_f32_cuda((const float *) src0->data, (float *) dst->data, scale, bias, ggml_nelements(src0), stream);
 }

+ 64 - 2
ggml/src/ggml.c

@@ -1026,6 +1026,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
     "ADD_REL_POS",
     "RWKV_WKV6",
     "GATED_LINEAR_ATTN",
+    "GATED_DELTA_RULE",
     "RWKV_WKV7",
     "SOLVE_TRI",
 
@@ -1045,7 +1046,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
     "GLU",
 };
 
-static_assert(GGML_OP_COUNT == 95, "GGML_OP_COUNT != 95");
+static_assert(GGML_OP_COUNT == 96, "GGML_OP_COUNT != 96");
 
 static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
     "none",
@@ -1135,6 +1136,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
     "add_rel_pos(x)",
     "rwkv_wkv6(k, v, r, tf, td, s)",
     "gated_linear_attn(k, v, q, gate, s)",
+    "gated_delta_rule(q, k, v, g, beta, s)",
     "rwkv_wkv7(r, w, k, v, a, b, s)",
     "A X = B, A triangular, solve X",
 
@@ -1154,7 +1156,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
     "glu(x)",
 };
 
-static_assert(GGML_OP_COUNT == 95, "GGML_OP_COUNT != 95");
+static_assert(GGML_OP_COUNT == 96, "GGML_OP_COUNT != 96");
 
 static_assert(GGML_OP_POOL_COUNT == 2, "GGML_OP_POOL_COUNT != 2");
 
@@ -5686,6 +5688,66 @@ struct ggml_tensor * ggml_gated_linear_attn(
     return result;
 }
 
+// ggml_gated_delta_rule
+
+struct ggml_tensor * ggml_gated_delta_rule(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * q,
+        struct ggml_tensor  * k,
+        struct ggml_tensor  * v,
+        struct ggml_tensor  * g,
+        struct ggml_tensor  * beta,
+        struct ggml_tensor  * state,
+        float                 scale,
+        float                 eps) {
+    GGML_ASSERT(ggml_is_contiguous(state));
+
+    GGML_ASSERT(q->type == k->type);
+    GGML_ASSERT(q->type == v->type);
+    GGML_ASSERT(q->type == g->type);
+    GGML_ASSERT(q->type == beta->type);
+    GGML_ASSERT(q->type == GGML_TYPE_F32 || q->type == GGML_TYPE_F16);
+    GGML_ASSERT(state->type == GGML_TYPE_F32 || state->type == GGML_TYPE_F16);
+
+    GGML_ASSERT(q->nb[0] == ggml_type_size(q->type));
+    GGML_ASSERT(k->nb[0] == ggml_type_size(k->type));
+    GGML_ASSERT(v->nb[0] == ggml_type_size(v->type));
+    GGML_ASSERT(g->nb[0] == ggml_type_size(g->type));
+    GGML_ASSERT(beta->nb[0] == ggml_type_size(beta->type));
+
+    const int64_t D = q->ne[0];
+    const int64_t H = q->ne[1];
+    const int64_t T = q->ne[2];
+    const int64_t B = q->ne[3];
+
+    GGML_ASSERT(ggml_are_same_shape(q, k));
+    GGML_ASSERT(ggml_are_same_shape(q, v));
+
+    GGML_ASSERT(ggml_is_3d(g));
+    GGML_ASSERT(g->ne[0] == H && g->ne[1] == T && g->ne[2] == B);
+
+    GGML_ASSERT(ggml_is_3d(beta));
+    GGML_ASSERT(beta->ne[0] == H && beta->ne[1] == T && beta->ne[2] == B);
+
+    GGML_ASSERT(state->ne[0] == D && state->ne[1] == D && state->ne[2] == H && state->ne[3] == B);
+
+    // concatenated output + new_state
+    struct ggml_tensor * result = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, ggml_nelements(v) + ggml_nelements(state));
+
+    ggml_set_op_params_f32(result, 0, scale);
+    ggml_set_op_params_f32(result, 1, eps);
+
+    result->op     = GGML_OP_GATED_DELTA_RULE;
+    result->src[0] = q;
+    result->src[1] = k;
+    result->src[2] = v;
+    result->src[3] = g;
+    result->src[4] = beta;
+    result->src[5] = state;
+
+    return result;
+}
+
 // ggml_rwkv_wkv7
 
 struct ggml_tensor * ggml_rwkv_wkv7(

+ 5 - 2
src/llama-model.cpp

@@ -7128,6 +7128,9 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params,
                         };
                     }
 
+                    ggml_type recurrent_type_k = GGML_TYPE_F32;
+                    ggml_type recurrent_type_v = GGML_TYPE_F32;
+
                     res = new llama_memory_hybrid(
                         /* model             */ *this,
                         /* attn_type_k       */ params.type_k,
@@ -7137,8 +7140,8 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params,
                         /* attn_n_pad        */ 1,
                         /* attn_n_swa        */ hparams.n_swa,
                         /* attn_swa_type     */ hparams.swa_type,
-                        /* recurrent_type_k  */ GGML_TYPE_F32,
-                        /* recurrent_type_v  */ GGML_TYPE_F32,
+                        /* recurrent_type_k  */ recurrent_type_k,
+                        /* recurrent_type_v  */ recurrent_type_v,
                         /* recurrent_kv_size */ std::max((uint32_t) 1, cparams.n_seq_max),
                         /* n_seq_max         */ cparams.n_seq_max,
                         /* offload           */ cparams.offload_kqv,

+ 0 - 23
src/models/models.h

@@ -439,35 +439,12 @@ private:
     ggml_tensor * build_layer_attn_linear(
          llm_graph_input_rs * inp,
                 ggml_tensor * cur,
-                ggml_tensor * causal_mask,
-                ggml_tensor * identity,
-                ggml_tensor * diag_mask,
                         int   il);
 
     ggml_tensor * build_layer_ffn(
                 ggml_tensor * cur,
                         int   il);
 
-    ggml_tensor * build_delta_net_chunking(
-                ggml_tensor * q,
-                ggml_tensor * k,
-                ggml_tensor * v,
-                ggml_tensor * g,
-                ggml_tensor * beta,
-                ggml_tensor * state,
-                ggml_tensor * causal_mask,
-                ggml_tensor * identity,
-                ggml_tensor * diag_mask,
-                        int   il);
-
-    ggml_tensor * build_delta_net_autoregressive(
-                ggml_tensor * q,
-                ggml_tensor * k,
-                ggml_tensor * v,
-                ggml_tensor * g,
-                ggml_tensor * beta,
-                ggml_tensor * state,
-                int           il);
 
     ggml_tensor * build_norm_gated(
                 ggml_tensor * input,

+ 29 - 384
src/models/qwen3next.cpp

@@ -16,17 +16,6 @@ llm_build_qwen3next::llm_build_qwen3next(const llama_model & model, const llm_gr
     ggml_tensor * inp_pos     = build_inp_pos();
     ggml_tensor * inp_out_ids = build_inp_out_ids();
 
-    ggml_tensor * causal_mask =
-        ggml_tri(ctx0, ggml_fill_inplace(ctx0, ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, CHUNK_SIZE, CHUNK_SIZE), 1.0f),
-                    GGML_TRI_TYPE_LOWER);
-
-    ggml_tensor * identity = ggml_diag(ctx0, ggml_fill_inplace(ctx0, ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, CHUNK_SIZE), 1.0f));
-    ggml_tensor * diag_mask = ggml_add(ctx0, causal_mask, identity);
-
-    ggml_build_forward_expand(gf, causal_mask);
-    ggml_build_forward_expand(gf, identity);
-    ggml_build_forward_expand(gf, diag_mask);
-
     for (int il = 0; il < n_layer; ++il) {
         ggml_tensor * inpSA = inpL;
 
@@ -36,7 +25,7 @@ llm_build_qwen3next::llm_build_qwen3next(const llama_model & model, const llm_gr
         // Determine layer type and build appropriate attention mechanism
         if (hparams.is_recurrent(il)) {
             // Linear attention layer (gated delta net)
-            cur = build_layer_attn_linear(inp->get_recr(), cur, causal_mask, identity, diag_mask, il);
+            cur = build_layer_attn_linear(inp->get_recr(), cur, il);
         } else {
             // Full attention layer
             cur = build_layer_attn(inp->get_attn(), cur, inp_pos, il);
@@ -86,345 +75,6 @@ llm_build_qwen3next::llm_build_qwen3next(const llama_model & model, const llm_gr
     ggml_build_forward_expand(gf, cur);
 }
 
-ggml_tensor * llm_build_qwen3next::build_delta_net_chunking(
-        ggml_tensor * q,
-        ggml_tensor * k,
-        ggml_tensor * v,
-        ggml_tensor * g,
-        ggml_tensor * beta,
-        ggml_tensor * state,
-        ggml_tensor * causal_mask,
-        ggml_tensor * identity,
-        ggml_tensor * diag_mask,
-        int           il) {
-    const int64_t S_k      = q->ne[0];
-    const int64_t H_k      = q->ne[1];
-    const int64_t n_tokens = q->ne[2];
-    const int64_t n_seqs   = q->ne[3];
-
-    const int64_t S_v = v->ne[0];
-    const int64_t H_v = v->ne[1];
-
-    GGML_ASSERT(v->ne[2] == n_tokens);
-    GGML_ASSERT(k->ne[2] == n_tokens);
-    GGML_ASSERT(g->ne[0] == H_v && g->ne[1] == n_tokens && g->ne[2] == n_seqs);
-    GGML_ASSERT(beta->ne[0] == H_v && beta->ne[2] == n_tokens && beta->ne[3] == n_seqs);
-    GGML_ASSERT(state->ne[0] == S_v && state->ne[1] == S_v * H_v && state->ne[2] == 1 && state->ne[3] == n_seqs);
-
-    GGML_ASSERT(q->ne[0] == S_k && q->ne[1] == H_k && q->ne[2] == n_tokens && q->ne[3] == n_seqs);
-    GGML_ASSERT(k->ne[0] == S_k && k->ne[1] == H_k && k->ne[2] == n_tokens && k->ne[3] == n_seqs);
-
-    GGML_ASSERT(H_k == H_v);  // we did a repeat to make sure this is the case
-
-    const float eps_norm = hparams.f_norm_rms_eps;
-
-    q = ggml_l2_norm(ctx0, q, eps_norm);
-    k = ggml_l2_norm(ctx0, k, eps_norm);
-
-    const float scale = 1.0f / sqrtf(S_v);
-
-    q = ggml_scale(ctx0, q, scale);
-
-    beta = ggml_sigmoid(ctx0, beta);
-
-    cb(q, "q_in", il);
-    cb(k, "k_in", il);
-    cb(v, "v_in", il);
-    cb(beta, "beta_in", il);
-    cb(g, "g_in", il);
-
-    q = ggml_cont_4d(ctx0, ggml_permute(ctx0, q, 0, 2, 1, 3), S_v, n_tokens, H_v, n_seqs);
-    k = ggml_cont_4d(ctx0, ggml_permute(ctx0, k, 0, 2, 1, 3), S_v, n_tokens, H_v, n_seqs);
-    v = ggml_cont_4d(ctx0, ggml_permute(ctx0, v, 0, 2, 1, 3), S_v, n_tokens, H_v, n_seqs);
-    g = ggml_cont_4d(ctx0, ggml_permute(ctx0, g, 2, 0, 3, 1), n_tokens, 1, H_k, n_seqs);
-
-    beta  = ggml_cont(ctx0, ggml_permute(ctx0, beta, 2, 0, 1, 3));
-    state = ggml_reshape_4d(ctx0, state, S_v, S_v, H_v, n_seqs);
-
-    cb(q, "q_perm", il);
-    cb(k, "k_perm", il);
-    cb(v, "v_perm", il);
-    cb(beta, "beta_perm", il);
-    cb(g, "g_perm", il);
-    cb(state, "state_in", il);
-
-    GGML_ASSERT(q->ne[1] == n_tokens && q->ne[0] == S_k && q->ne[2] == H_k && q->ne[3] == n_seqs);
-    GGML_ASSERT(k->ne[1] == n_tokens && k->ne[0] == S_k && k->ne[2] == H_k && k->ne[3] == n_seqs);
-    GGML_ASSERT(v->ne[1] == n_tokens && v->ne[0] == S_v && v->ne[2] == H_k && v->ne[3] == n_seqs);
-    GGML_ASSERT(beta->ne[1] == n_tokens && beta->ne[2] == H_k && beta->ne[0] == 1 && beta->ne[3] == n_seqs);
-
-    // Do padding
-    const int64_t chunk_size = CHUNK_SIZE;
-
-    const int64_t pad = (chunk_size - n_tokens % chunk_size) % chunk_size;
-    const int64_t n_chunks = (n_tokens + pad) / chunk_size;
-
-    q = ggml_pad(ctx0, q, 0, pad, 0, 0);
-    k = ggml_pad(ctx0, k, 0, pad, 0, 0);
-    v = ggml_pad(ctx0, v, 0, pad, 0, 0);
-    g = ggml_pad(ctx0, g, pad, 0, 0, 0);
-    beta = ggml_pad(ctx0, beta, 0, pad, 0, 0);
-
-    cb(q, "q_pad", il);
-    cb(k, "k_pad", il);
-    cb(v, "v_pad", il);
-    cb(beta, "beta_pad", il);
-    cb(g, "g_pad", il);
-
-    ggml_tensor * v_beta = ggml_mul(ctx0, v, beta);
-    ggml_tensor * k_beta = ggml_mul(ctx0, k, beta);
-
-    cb(v_beta, "v_beta", il);
-    cb(k_beta, "k_beta", il);
-
-    q      = ggml_reshape_4d(ctx0, q,      S_k, chunk_size, n_chunks, H_k * n_seqs);
-    k      = ggml_reshape_4d(ctx0, k,      S_k, chunk_size, n_chunks, H_k * n_seqs);
-    k_beta = ggml_reshape_4d(ctx0, k_beta, S_k, chunk_size, n_chunks, H_k * n_seqs);
-    v      = ggml_reshape_4d(ctx0, v,      S_v, chunk_size, n_chunks, H_v * n_seqs);
-    v_beta = ggml_reshape_4d(ctx0, v_beta, S_v, chunk_size, n_chunks, H_v * n_seqs);
-
-    g    = ggml_reshape_4d(ctx0, g, chunk_size, 1, n_chunks, H_k * n_seqs);
-    beta = ggml_reshape_4d(ctx0, beta, 1, chunk_size, n_chunks, H_k * n_seqs);
-
-    ggml_tensor * g_cumsum = ggml_cumsum(ctx0, g);
-
-    cb(g_cumsum, "g_cumsum", il);
-
-    ggml_tensor * gcs_i = ggml_reshape_4d(ctx0, g_cumsum, chunk_size, 1, n_chunks, H_v * n_seqs);
-    ggml_tensor * gcs_j = ggml_reshape_4d(ctx0, g_cumsum, 1, chunk_size, n_chunks, H_v * n_seqs);
-
-    ggml_tensor * gcs_j_broadcast =
-        ggml_repeat_4d(ctx0, gcs_j, chunk_size, chunk_size, n_chunks, H_v * n_seqs);
-
-    ggml_tensor * decay_mask = ggml_sub(ctx0, gcs_j_broadcast, gcs_i);
-
-    cb(decay_mask, "decay_mask", il);
-
-    decay_mask = ggml_mul(ctx0, decay_mask, diag_mask);
-    decay_mask = ggml_exp(ctx0, decay_mask);
-    decay_mask = ggml_mul(ctx0, decay_mask, diag_mask);
-
-    ggml_tensor * kmulkbeta = ggml_mul_mat(ctx0, k, k_beta);
-
-    ggml_tensor * k_decay = ggml_mul(ctx0, kmulkbeta, decay_mask);
-    ggml_tensor * attn    = ggml_neg(ctx0, ggml_mul(ctx0, k_decay, causal_mask));
-
-    cb(attn, "attn_pre_solve", il);
-
-    ggml_tensor * attn_lower = ggml_mul(ctx0, attn, causal_mask);
-    ggml_tensor * lhs        = ggml_sub(ctx0, ggml_repeat(ctx0, identity, attn_lower), attn_lower);
-
-    ggml_tensor * lin_solve  = ggml_solve_tri(ctx0, lhs, attn, true, true, false);
-    attn                     = ggml_mul(ctx0, lin_solve, causal_mask);
-    attn                     = ggml_add(ctx0, attn, identity);
-
-    cb(attn, "attn_solved", il);
-
-    v = ggml_mul_mat(ctx0, ggml_cont(ctx0, ggml_transpose(ctx0, v_beta)), attn);
-
-    ggml_tensor * g_cumsum_t = ggml_cont(ctx0, ggml_transpose(ctx0, g_cumsum));
-    ggml_tensor * gexp       = ggml_exp(ctx0, g_cumsum_t);
-
-    ggml_tensor * kbeta_gexp = ggml_mul(ctx0, k_beta, gexp);
-
-    cb(kbeta_gexp, "kbeta_gexp", il);
-
-    ggml_tensor * k_cumdecay =
-        ggml_cont(ctx0, ggml_transpose(ctx0, ggml_mul_mat(ctx0, attn, ggml_cont(ctx0, ggml_transpose(ctx0, kbeta_gexp)))));
-
-    cb(k_cumdecay, "k_cumdecay", il);
-
-    ggml_tensor * core_attn_out = nullptr;
-    ggml_tensor * new_state = ggml_dup(ctx0, state);
-
-    cb(new_state, "new_state", il);
-
-    for (int64_t chunk = 0; chunk < n_chunks; chunk++) {
-        auto chunkify = [=](ggml_tensor * t) {
-            return ggml_cont(ctx0, ggml_view_4d(ctx0, t, t->ne[0], chunk_size, 1, t->ne[3],
-                t->nb[1], t->nb[2], t->nb[3], t->nb[2] * chunk));
-        };
-
-        auto chunkify_g = [=](ggml_tensor * t) {
-            return ggml_cont(ctx0, ggml_view_4d(ctx0, t, chunk_size, t->ne[1], 1, t->ne[3],
-                t->nb[1], t->nb[2], t->nb[3], t->nb[2] * chunk));
-        };
-
-        ggml_tensor * k_chunk = chunkify(k);
-        ggml_tensor * q_chunk = chunkify(q);
-        ggml_tensor * v_chunk = chunkify(v);
-
-        ggml_tensor * g_cs_chunk = chunkify_g(g_cumsum);
-        ggml_tensor * g_cs_chunk_t = ggml_cont(ctx0, ggml_transpose(ctx0, g_cs_chunk));
-
-        ggml_tensor * decay_mask_chunk = chunkify(decay_mask);
-        ggml_tensor * k_cumdecay_chunk = chunkify(k_cumdecay);
-
-        ggml_tensor * gexp_chunk = ggml_exp(ctx0, g_cs_chunk_t);
-
-        // attn = (q_i @ k_i.transpose(-1, -2) * decay_mask[:, :, i]).masked_fill_(mask, 0)
-        attn = ggml_mul_mat(ctx0, k_chunk, q_chunk);
-        attn = ggml_mul(ctx0, attn, decay_mask_chunk);
-        attn = ggml_mul(ctx0, attn, diag_mask);
-
-        ggml_tensor * state_t = ggml_cont_4d(ctx0, ggml_permute(ctx0, new_state, 1, 0, 2, 3), S_v, S_v, 1, H_v * n_seqs);
-
-        // v_prime = (k_cumdecay[:, :, i]) @ last_recurrent_state
-        ggml_tensor * v_prime = ggml_mul_mat(ctx0, state_t, k_cumdecay_chunk);
-
-        // v_new = v_i - v_prime
-        ggml_tensor * v_new = ggml_sub(ctx0, ggml_repeat(ctx0, v_chunk, v_prime), v_prime);
-        ggml_tensor * v_new_t = ggml_cont(ctx0, ggml_transpose(ctx0, v_new));
-
-        // attn_inter = (q_i * g[:, :, i, :, None].exp()) @ last_recurrent_state
-        ggml_tensor * q_g_exp    = ggml_mul(ctx0, q_chunk, gexp_chunk);
-        ggml_tensor * attn_inter = ggml_mul_mat(ctx0, state_t, q_g_exp);
-
-        // core_attn_out[:, :, i] = attn_inter + attn @ v_new
-        ggml_tensor * v_attn = ggml_mul_mat(ctx0, v_new_t, attn);
-
-        ggml_tensor * core_attn_out_chunk = ggml_add(ctx0, attn_inter, v_attn);
-
-        core_attn_out = core_attn_out == nullptr ? core_attn_out_chunk : ggml_concat(ctx0, core_attn_out, core_attn_out_chunk, 1);
-
-        // g_last = torch.clamp(g_cum[:, :, -1], max=50.0).exp().unsqueeze(-1).unsqueeze(-1)
-        // g_diff = torch.clamp(g_cum[:, :, -1:] - g_cum, max=50.0).exp()
-        // key_gdiff = key * g_diff.unsqueeze(-1)
-        // kgdmulvnew = (key_gdiff).transpose(-1, -2) @ v_new
-        // last_recurrent_state = last_recurrent_state * g_last + kgdmulvnew
-
-        ggml_tensor * g_cum_last =
-            ggml_cont(ctx0, ggml_view_4d(ctx0, g_cs_chunk_t, g_cs_chunk_t->ne[0], 1, g_cs_chunk_t->ne[2], g_cs_chunk_t->ne[3],
-                                        g_cs_chunk_t->nb[1], g_cs_chunk_t->nb[2], g_cs_chunk_t->nb[3],
-                                        g_cs_chunk_t->nb[0] * (g_cs_chunk_t->ne[1] - 1)));
-
-        ggml_tensor * gexp_last =
-            ggml_reshape_4d(ctx0, ggml_exp(ctx0, g_cum_last), 1, 1, g_cum_last->ne[0] * g_cum_last->ne[2], g_cum_last->ne[3]);
-
-        ggml_tensor * g_cum_last_3d =
-            ggml_reshape_3d(ctx0, g_cum_last, g_cum_last->ne[0], g_cum_last->ne[2], g_cum_last->ne[3]);
-
-        ggml_tensor * g_cumsum_3d = ggml_reshape_3d(ctx0, g_cs_chunk, g_cs_chunk->ne[0], g_cs_chunk->ne[2], g_cs_chunk->ne[3]);
-
-        ggml_tensor * g_diff = ggml_neg(ctx0, ggml_sub(ctx0, g_cumsum_3d, g_cum_last_3d));
-
-        ggml_tensor * g_diff_exp = ggml_exp(ctx0, g_diff);
-
-        ggml_tensor * key_gdiff = ggml_mul(ctx0, k_chunk,
-                                        ggml_reshape_4d(ctx0, g_diff_exp, 1, g_diff_exp->ne[0], g_diff_exp->ne[1],
-                                                        g_diff_exp->ne[2] * g_diff_exp->ne[3]));
-
-        ggml_tensor * kgdmulvnew = ggml_mul_mat(ctx0, v_new_t, ggml_cont(ctx0, ggml_transpose(ctx0, key_gdiff)));
-
-        new_state = ggml_add(ctx0,
-            ggml_mul(ctx0, new_state, ggml_reshape_4d(ctx0, gexp_last, gexp_last->ne[0], gexp_last->ne[1], H_v, n_seqs)),
-            ggml_reshape_4d(ctx0, kgdmulvnew, kgdmulvnew->ne[0], kgdmulvnew->ne[1], H_v, n_seqs));
-    }
-
-    core_attn_out = ggml_cont_4d(ctx0, core_attn_out, S_v, chunk_size * n_chunks, H_v, n_seqs);
-
-    ggml_tensor * output_tokens = ggml_view_4d(ctx0, core_attn_out, S_v, n_tokens, H_v, n_seqs, core_attn_out->nb[1], core_attn_out->nb[2], core_attn_out->nb[3], 0);
-    cb(output_tokens, "output_tokens", il);
-
-    // flatten output
-    ggml_tensor * flat_output =
-        ggml_cont_1d(ctx0, ggml_permute(ctx0, output_tokens, 0, 2, 1, 3), S_v * H_v * n_tokens * n_seqs);
-
-    ggml_tensor * flat_state = ggml_cont_1d(ctx0, new_state, S_v * S_v * H_v * n_seqs);
-
-    return ggml_concat(ctx0, flat_output, flat_state, 0);
-}
-
-ggml_tensor * llm_build_qwen3next::build_delta_net_autoregressive(
-        ggml_tensor * q,
-        ggml_tensor * k,
-        ggml_tensor * v,
-        ggml_tensor * g,
-        ggml_tensor * beta,
-        ggml_tensor * state,
-        int           il) {
-    const int64_t S_k      = q->ne[0];
-    const int64_t H_k      = q->ne[1];
-    const int64_t n_tokens = q->ne[2];
-    const int64_t n_seqs   = q->ne[3];
-
-    const int64_t S_v = v->ne[0];
-    const int64_t H_v = v->ne[1];
-
-    GGML_ASSERT(n_tokens == 1);  // This function is optimized for single token processing
-    GGML_ASSERT(v->ne[2] == n_tokens);
-    GGML_ASSERT(k->ne[2] == n_tokens);
-    GGML_ASSERT(g->ne[0] == H_v && g->ne[1] == n_tokens && g->ne[2] == n_seqs);
-    GGML_ASSERT(beta->ne[0] == H_v && beta->ne[2] == n_tokens && beta->ne[3] == n_seqs);
-    GGML_ASSERT(state->ne[0] == S_v && state->ne[1] == S_v * H_v && state->ne[2] == 1 && state->ne[3] == n_seqs);
-
-    GGML_ASSERT(q->ne[0] == S_k && q->ne[1] == H_k && q->ne[2] == n_tokens && q->ne[3] == n_seqs);
-    GGML_ASSERT(k->ne[0] == S_k && k->ne[1] == H_k && k->ne[2] == n_tokens && k->ne[3] == n_seqs);
-
-    GGML_ASSERT(H_k == H_v);  // we did a repeat to make sure this is the case
-
-    const float eps_norm = hparams.f_norm_rms_eps;
-
-    q = ggml_l2_norm(ctx0, q, eps_norm);
-    k = ggml_l2_norm(ctx0, k, eps_norm);
-
-    const float scale = 1.0f / sqrtf(S_v);
-
-    q    = ggml_scale(ctx0, q, scale);
-    beta = ggml_sigmoid(ctx0, beta);
-
-    cb(q, "q_in", il);
-    cb(k, "k_in", il);
-    cb(v, "v_in", il);
-    cb(beta, "beta_in", il);
-    cb(g, "g_in", il);
-
-    state = ggml_reshape_4d(ctx0, state, S_v, S_v, H_v, n_seqs);
-
-    ggml_tensor * g_t    = ggml_reshape_4d(ctx0, ggml_transpose(ctx0, g), 1, 1, H_k, n_seqs);
-    ggml_tensor * beta_t = ggml_reshape_4d(ctx0, ggml_transpose(ctx0, beta), 1, 1, H_k, n_seqs);
-
-    // Apply exponential to g_t
-    g_t = ggml_exp(ctx0, g_t);
-
-    // Apply the gated delta rule for the single timestep
-    // last_recurrent_state = last_recurrent_state * g_t
-    state = ggml_mul(ctx0, state, g_t);
-
-    // kv_mem = (last_recurrent_state * k_t.unsqueeze(-1)).sum(dim=-2)
-    ggml_tensor * k_t_unsqueezed = ggml_reshape_4d(ctx0, k, 1, S_v, H_v, n_seqs);
-    ggml_tensor * kv_mem         = ggml_mul(ctx0, state, k_t_unsqueezed);
-    // we need to sum over dim=-2, so we transpose, sum, then transpose again
-    kv_mem = ggml_transpose(ctx0, ggml_sum_rows(ctx0, ggml_cont(ctx0, ggml_transpose(ctx0, kv_mem))));
-
-    // v_t = v.unsqueeze(2) (we insert the singleton dimension after n_seqs and H_v)
-    ggml_tensor * v_t    = ggml_reshape_4d(ctx0, v, S_v, 1, H_v, n_seqs);
-    // delta = (v_t - kv_mem) * beta_t
-    ggml_tensor * v_diff = ggml_sub(ctx0, v_t, kv_mem);  // both should be [S_v, 1, H_v, n_seqs]
-    ggml_tensor * delta  = ggml_mul(ctx0, v_diff, beta_t);
-
-    // last_recurrent_state = last_recurrent_state + k_t.unsqueeze(-1) * delta
-    ggml_tensor * k_t_delta = ggml_mul(ctx0, ggml_repeat_4d(ctx0, k_t_unsqueezed, S_v, S_v, H_v, n_seqs), delta);
-    state                   = ggml_add(ctx0, state, k_t_delta);
-
-    // Compute the attention output
-    // core_attn_out = (last_recurrent_state * q_t.unsqueeze(-1)).sum(dim=-2)
-    ggml_tensor * q_t_unsqueezed = ggml_reshape_4d(ctx0, q, 1, S_v, H_v, n_seqs);  // unsqueeze q_t
-    ggml_tensor * state_q        = ggml_mul(ctx0, state, q_t_unsqueezed);
-    // again, since it's over dim = -2, transpose, sum, transpose back
-    ggml_tensor * core_attn_out =
-        ggml_transpose(ctx0, ggml_sum_rows(ctx0, ggml_cont(ctx0, ggml_transpose(ctx0, state_q))));
-
-    // core_attn_out should be [S_v, 1, H_v, n_seqs] after this
-    cb(core_attn_out, "output_tokens", il);
-    cb(state, "new_state", il);
-
-    // flatten output, no need to permute since n_tokens is 1 so [S_v, 1, H_v, n_seqs] and [S_v, H_v, 1, n_seqs] are equivalent memory-layout wise
-    ggml_tensor * flat_output = ggml_reshape_1d(ctx0, core_attn_out, S_v * H_v * n_tokens * n_seqs);
-    ggml_tensor * flat_state  = ggml_reshape_1d(ctx0, state, S_v * S_v * H_v * n_seqs);
-
-    return ggml_concat(ctx0, flat_output, flat_state, 0);
-}
 
 ggml_tensor * llm_build_qwen3next::build_norm_gated(
         ggml_tensor * input,
@@ -526,9 +176,6 @@ ggml_tensor * llm_build_qwen3next::build_layer_attn(
 ggml_tensor * llm_build_qwen3next::build_layer_attn_linear(
         llm_graph_input_rs * inp,
         ggml_tensor *        cur,
-        ggml_tensor *        causal_mask,
-        ggml_tensor *        identity,
-        ggml_tensor *        diag_mask,
         int                  il) {
     const auto * mctx_cur = inp->mctx;
 
@@ -645,9 +292,6 @@ ggml_tensor * llm_build_qwen3next::build_layer_attn_linear(
     qkv_mixed = ggml_permute(ctx0, qkv_mixed, 1, 0, 2, 3);
     cb(qkv_mixed, "qkv_mixed_permuted", il);
 
-    // Calculate the total conv dimension
-    int64_t qkv_dim = head_k_dim * num_k_heads * 2 + head_v_dim * num_v_heads;
-
     // Calculate convolution kernel size
     ggml_tensor * conv_kernel      = model.layers[il].ssm_conv1d;
     const int64_t conv_kernel_size = conv_kernel->ne[0];
@@ -674,37 +318,33 @@ ggml_tensor * llm_build_qwen3next::build_layer_attn_linear(
     cb(conv_states_all, "conv_states_updated", il);
 
     // Apply SSM convolution
-    ggml_tensor * conv_output_proper = ggml_ssm_conv(ctx0, conv_input, conv_kernel);
-    cb(conv_output_proper, "conv_output_raw", il);
+    ggml_tensor * conv_output = ggml_ssm_conv(ctx0, conv_input, conv_kernel);
+    cb(conv_output, "conv_output_raw", il);
 
-    conv_output_proper = ggml_cont(ctx0, ggml_transpose(ctx0, conv_output_proper));
-    cb(conv_output_proper, "conv_output_pre_silu", il);
-
-    ggml_tensor * conv_output_silu = ggml_silu(ctx0, conv_output_proper);
+    ggml_tensor * conv_output_silu = ggml_silu(ctx0, conv_output);
     cb(conv_output_silu, "conv_output_silu", il);
 
-    ggml_tensor * conv_qkv_mix =
-        ggml_cont_2d(ctx0, ggml_transpose(ctx0, conv_output_silu), qkv_dim, n_seq_tokens * n_seqs);
-    cb(conv_qkv_mix, "conv_qkv_mix", il);
+    const size_t qkv_stride_t = conv_output_silu->nb[1];
+    const size_t qkv_stride_b = conv_output_silu->nb[2];
+    const size_t q_stride_h   = head_k_dim * ggml_element_size(conv_output_silu);
+    const size_t v_stride_h   = head_v_dim * ggml_element_size(conv_output_silu);
+    const size_t k_offset     = head_k_dim * num_k_heads * ggml_element_size(conv_output_silu);
+    const size_t v_offset     = 2 * head_k_dim * num_k_heads * ggml_element_size(conv_output_silu);
 
-    // Extract the convolved Q, K, V from conv_output
+    // Extract the convolved Q, K, V directly as strided views (avoid extra copies).
     ggml_tensor * q_conv =
-        ggml_view_2d(ctx0, conv_qkv_mix, head_k_dim * num_k_heads, n_seq_tokens * n_seqs, conv_qkv_mix->nb[1], 0);
+        ggml_view_4d(ctx0, conv_output_silu, head_k_dim, num_k_heads, n_seq_tokens, n_seqs,
+                     q_stride_h, qkv_stride_t, qkv_stride_b, 0);
     cb(q_conv, "q_conv", il);
     ggml_tensor * k_conv =
-        ggml_view_2d(ctx0, conv_qkv_mix, head_k_dim * num_k_heads, n_seq_tokens * n_seqs, conv_qkv_mix->nb[1],
-                     head_k_dim * num_k_heads * ggml_element_size(conv_qkv_mix));
+        ggml_view_4d(ctx0, conv_output_silu, head_k_dim, num_k_heads, n_seq_tokens, n_seqs,
+                     q_stride_h, qkv_stride_t, qkv_stride_b, k_offset);
     cb(k_conv, "k_conv", il);
     ggml_tensor * v_conv =
-        ggml_view_2d(ctx0, conv_qkv_mix, head_v_dim * num_v_heads, n_seq_tokens * n_seqs, conv_qkv_mix->nb[1],
-                     2 * head_k_dim * num_k_heads * ggml_element_size(conv_qkv_mix));
+        ggml_view_4d(ctx0, conv_output_silu, head_v_dim, num_v_heads, n_seq_tokens, n_seqs,
+                     v_stride_h, qkv_stride_t, qkv_stride_b, v_offset);
     cb(v_conv, "v_conv", il);
 
-    // Unsqueeze them
-    q_conv = ggml_cont_4d(ctx0, q_conv, head_k_dim, num_k_heads, n_seq_tokens, n_seqs);
-    k_conv = ggml_cont_4d(ctx0, k_conv, head_k_dim, num_k_heads, n_seq_tokens, n_seqs);
-    v_conv = ggml_cont_4d(ctx0, v_conv, head_v_dim, num_v_heads, n_seq_tokens, n_seqs);
-
     beta = ggml_cont_4d(ctx0, b, num_v_heads, 1, n_seq_tokens, n_seqs);
 
     ggml_tensor * state = build_rs(inp, ssm_states_all, hparams.n_embd_s(), n_seqs);
@@ -716,6 +356,9 @@ ggml_tensor * llm_build_qwen3next::build_layer_attn_linear(
         GGML_ASSERT(num_v_heads % num_k_heads == 0);
         int64_t repeat_factor = num_v_heads / num_k_heads;
 
+        q_conv = ggml_cont_4d(ctx0, q_conv, head_k_dim, num_k_heads, n_seq_tokens, n_seqs);
+        k_conv = ggml_cont_4d(ctx0, k_conv, head_k_dim, num_k_heads, n_seq_tokens, n_seqs);
+
         // repeat interleave: reshape to (repeat part, 1, remaining part), do repeat, then reshape back
         ggml_tensor * q_reshaped = ggml_reshape_3d(ctx0, q_conv, head_k_dim, 1, num_k_heads * n_seq_tokens * n_seqs);
         ggml_tensor * k_reshaped = ggml_reshape_3d(ctx0, k_conv, head_k_dim, 1, num_k_heads * n_seq_tokens * n_seqs);
@@ -737,13 +380,15 @@ ggml_tensor * llm_build_qwen3next::build_layer_attn_linear(
     cb(k_conv, "k_conv_predelta", il);
     cb(v_conv, "v_conv_predelta", il);
 
-    // Choose between build_delta_net_chunking, build_delta_net_recurrent, and build_delta_net_autoregressive based on n_tokens
-    ggml_tensor * attn_out;
-    if (n_seq_tokens == 1) {
-        attn_out = build_delta_net_autoregressive(q_conv, k_conv, v_conv, gate, beta, state, il);
-    } else {
-        attn_out = build_delta_net_chunking(q_conv, k_conv, v_conv, gate, beta, state, causal_mask, identity, diag_mask, il);
-    }
+    // Fused gated delta rule (handles both prefill and decode)
+    const float q_scale = 1.0f / sqrtf((float) head_v_dim);
+    const float eps_norm = hparams.f_norm_rms_eps;
+
+    ggml_tensor * beta_3d  = ggml_reshape_3d(ctx0, beta, num_v_heads, n_seq_tokens, n_seqs);
+    ggml_tensor * state_4d = ggml_reshape_4d(ctx0, state, head_v_dim, head_v_dim, num_v_heads, n_seqs);
+
+    ggml_tensor * attn_out = ggml_gated_delta_rule(ctx0, q_conv, k_conv, v_conv, gate, beta_3d, state_4d, q_scale, eps_norm);
     cb(attn_out, "attn_out", il);
 
     // The tensors were concatenated 1d, so we need to extract them 1d as well
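
The extraction itself is cut off by this hunk; a hedged sketch of what it has to do, using the shapes established above (core_out and new_state are illustrative names, not code from the commit):

    // output tokens come first in the flat result, the updated state after them
    const int64_t n_out_elems   = head_v_dim * num_v_heads * n_seq_tokens * n_seqs;
    const int64_t n_state_elems = head_v_dim * head_v_dim * num_v_heads * n_seqs;

    ggml_tensor * core_out  = ggml_view_1d(ctx0, attn_out, n_out_elems, 0);
    ggml_tensor * new_state = ggml_view_1d(ctx0, attn_out, n_state_elems,
                                           n_out_elems * ggml_element_size(attn_out));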

+ 54 - 0
tests/test-backend-ops.cpp

@@ -3516,6 +3516,55 @@ struct test_gla : public test_case {
     }
 };
 
+// GGML_OP_GATED_DELTA_RULE
+struct test_gated_delta_rule : public test_case {
+    const ggml_type type;
+
+    const int64_t head_count;
+    const int64_t head_dim;
+    const int64_t n_seq_tokens;
+    const int64_t n_seqs;
+    const float eps;
+
+    ggml_tensor * t_g = nullptr;
+    ggml_tensor * t_state = nullptr;
+
+    std::string vars() override {
+        return VARS_TO_STR5(type, head_count, head_dim, n_seq_tokens, n_seqs);
+    }
+
+    test_gated_delta_rule(ggml_type type = GGML_TYPE_F32,
+            int64_t head_count = 8, int64_t head_dim = 64, int64_t n_seq_tokens = 32, int64_t n_seqs = 4, float eps = 1e-6f)
+        : type(type), head_count(head_count), head_dim(head_dim), n_seq_tokens(n_seq_tokens), n_seqs(n_seqs), eps(eps) {}
+
+    ggml_tensor * build_graph(ggml_context * ctx) override {
+        ggml_tensor * q = ggml_new_tensor(ctx, type, 4, std::vector<int64_t>{ head_dim, head_count, n_seq_tokens, n_seqs }.data());
+        ggml_tensor * k = ggml_new_tensor(ctx, type, 4, std::vector<int64_t>{ head_dim, head_count, n_seq_tokens, n_seqs }.data());
+        ggml_tensor * v = ggml_new_tensor(ctx, type, 4, std::vector<int64_t>{ head_dim, head_count, n_seq_tokens, n_seqs }.data());
+        t_g = ggml_new_tensor(ctx, type, 3, std::vector<int64_t>{ head_count, n_seq_tokens, n_seqs }.data());
+        ggml_tensor * beta = ggml_new_tensor(ctx, type, 3, std::vector<int64_t>{ head_count, n_seq_tokens, n_seqs }.data());
+        t_state = ggml_new_tensor(ctx, type, 4, std::vector<int64_t>{ head_dim, head_dim, head_count, n_seqs }.data());
+
+        ggml_tensor * out = ggml_gated_delta_rule(ctx, q, k, v, t_g, beta, t_state, powf((float) head_dim, -0.5f), eps);
+        return out;
+    }
+
+    void initialize_tensors(ggml_context * ctx) override {
+        for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != nullptr; t = ggml_get_next_tensor(ctx, t)) {
+            if (t == t_g) {
+                // keep exp(g) bounded over long sequences
+                init_tensor_uniform(t, -0.2f, 0.0f);
+                continue;
+            }
+            if (t == t_state) {
+                init_tensor_uniform(t, -0.1f, 0.1f);
+                continue;
+            }
+            init_tensor_uniform(t);
+        }
+    }
+};
+
 // GGML_OP_RWKV_WKV7
 struct test_rwkv_wkv7 : public test_case {
     const ggml_type type;
@@ -7322,6 +7371,11 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
     test_cases.emplace_back(new test_gla(GGML_TYPE_F32, 32, 64, 32, 4));
     test_cases.emplace_back(new test_gla(GGML_TYPE_F32, 32, 64, 128, 4));
 
+    test_cases.emplace_back(new test_gated_delta_rule(GGML_TYPE_F32, 8, 64, 1, 1));
+    test_cases.emplace_back(new test_gated_delta_rule(GGML_TYPE_F32, 8, 64, 32, 1));
+    test_cases.emplace_back(new test_gated_delta_rule(GGML_TYPE_F32, 8, 64, 32, 4));
+    test_cases.emplace_back(new test_gated_delta_rule(GGML_TYPE_F32, 4, 128, 16, 2));
+
 #if 0
     // > 4GB A matrix. Too slow to be enabled by default.
     test_cases.emplace_back(new test_mul_mat(GGML_TYPE_F16, GGML_TYPE_F16,  900000,  3, 2592, {1, 1}, {1, 1}));
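
Assuming the usual per-op filter of test-backend-ops also matches the new op name (not verified here), a run such as "test-backend-ops test -o GATED_DELTA_RULE" would exercise only these cases and compare the CUDA kernel against the CPU reference.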