Rewrite recurrent delta + softplus to separate ops

Piotr Wilkin, 3 months ago
parent commit 7eef0bd948

+ 35 - 0
ggml/include/ggml.h

@@ -544,6 +544,7 @@ extern "C" {
         GGML_OP_GATED_LINEAR_ATTN,
         GGML_OP_RWKV_WKV7,
         GGML_OP_DELTA_NET,
+        GGML_OP_DELTA_NET_RECURRENT,
 
         GGML_OP_UNARY,
 
@@ -578,6 +579,8 @@ extern "C" {
         GGML_UNARY_OP_HARDSWISH,
         GGML_UNARY_OP_HARDSIGMOID,
         GGML_UNARY_OP_EXP,
+        GGML_UNARY_OP_EXPM1,
+        GGML_UNARY_OP_SOFTPLUS,
         GGML_UNARY_OP_GELU_ERF,
 
         GGML_UNARY_OP_COUNT,
@@ -961,6 +964,22 @@ extern "C" {
             struct ggml_context * ctx,
             struct ggml_tensor  * a);
 
+    GGML_API struct ggml_tensor * ggml_expm1(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a);
+
+    GGML_API struct ggml_tensor * ggml_expm1_inplace(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a);
+
+    GGML_API struct ggml_tensor * ggml_softplus(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a);
+
+    GGML_API struct ggml_tensor * ggml_softplus_inplace(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a);
+
     GGML_API struct ggml_tensor * ggml_sin(
             struct ggml_context * ctx,
             struct ggml_tensor  * a);
@@ -1164,6 +1183,22 @@ extern "C" {
             struct ggml_context * ctx,
             struct ggml_tensor  * a);
 
+    GGML_API struct ggml_tensor * ggml_expm1(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a);
+
+    GGML_API struct ggml_tensor * ggml_expm1_inplace(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a);
+
+    GGML_API struct ggml_tensor * ggml_softplus(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a);
+
+    GGML_API struct ggml_tensor * ggml_softplus_inplace(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a);
+
     // gated linear unit ops
     // A: n columns, r rows,
     // result is n / 2 columns, r rows,
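
For context, here is a minimal CPU-only usage sketch of the new ggml_expm1/ggml_softplus API declared above. It is not part of this commit; the includes and the ggml_graph_compute_with_ctx() path are assumptions about how a caller would exercise the ops.

    // Build y = softplus(x) and z = expm1(x) and evaluate them on the CPU backend.
    #include "ggml.h"
    #include "ggml-cpu.h"

    int main(void) {
        struct ggml_init_params ip = { /*mem_size =*/ 16*1024*1024, /*mem_buffer =*/ NULL, /*no_alloc =*/ false };
        struct ggml_context * ctx = ggml_init(ip);

        struct ggml_tensor * x = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 4);
        float * xd = (float *) x->data;
        xd[0] = -1.0f; xd[1] = 0.0f; xd[2] = 1.0f; xd[3] = 25.0f;

        struct ggml_tensor * y = ggml_softplus(ctx, x);  // log(1 + exp(x)), ~x for large x
        struct ggml_tensor * z = ggml_expm1(ctx, x);     // exp(x) - 1

        struct ggml_cgraph * gf = ggml_new_graph(ctx);
        ggml_build_forward_expand(gf, y);
        ggml_build_forward_expand(gf, z);
        ggml_graph_compute_with_ctx(ctx, gf, /*n_threads =*/ 1);

        ggml_free(ctx);
        return 0;
    }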

+ 7 - 0
ggml/src/ggml-cpu/ggml-cpu.c

@@ -2010,6 +2010,10 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
             {
                 ggml_compute_forward_delta_net_f32(params, tensor);
             } break;
+        case GGML_OP_DELTA_NET_RECURRENT:
+            {
+                ggml_compute_forward_delta_net_recurrent_f32(params, tensor);
+            } break;
         case GGML_OP_MAP_CUSTOM1:
             {
                 ggml_compute_forward_map_custom1(params, tensor);
@@ -2193,6 +2197,8 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) {
                 case GGML_UNARY_OP_HARDSWISH:
                 case GGML_UNARY_OP_HARDSIGMOID:
                 case GGML_UNARY_OP_EXP:
+                case GGML_UNARY_OP_SOFTPLUS:
+                case GGML_UNARY_OP_EXPM1:
                     {
                         n_tasks = 1;
                     } break;
@@ -2288,6 +2294,7 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) {
         case GGML_OP_POOL_1D:
         case GGML_OP_POOL_2D:
         case GGML_OP_POOL_2D_BACK:
+        case GGML_OP_DELTA_NET_RECURRENT:
             {
                 n_tasks = 1;
             } break;

+ 202 - 0
ggml/src/ggml-cpu/ops.cpp

@@ -9861,6 +9861,14 @@ void ggml_compute_forward_unary(
             {
                 ggml_compute_forward_exp(params, dst);
             } break;
+        case GGML_UNARY_OP_EXPM1:
+            {
+                ggml_compute_forward_expm1(params, dst);
+            } break;
+        case GGML_UNARY_OP_SOFTPLUS:
+            {
+                ggml_compute_forward_softplus(params, dst);
+            } break;
         default:
             {
                 GGML_ABORT("fatal error");
@@ -10874,6 +10882,200 @@ void ggml_compute_forward_delta_net_f32(const ggml_compute_params * params, ggml
     }    
 }
 
+static void print_debug_info(float * data, size_t size, const char * name, int64_t token) {
+    GGML_LOG_INFO("\nggml-debug: %s (%ld) first 5 values: [%.6f, %.6f, %.6f, %.6f, %.6f, ...]\n", 
+        name, token, data[0], data[1], data[2], data[3], data[4]);
+    double sum = 0.0;
+    for (unsigned int i = 0; i < size; i++) {
+        sum += data[i];
+    }
+    GGML_LOG_INFO("sum = %.10f\n", sum);
+}
+
+void ggml_compute_forward_delta_net_recurrent_f32(const ggml_compute_params * params, ggml_tensor * dst) {
+    const struct ggml_tensor * src0 = dst->src[0];  // q_tokens
+    const struct ggml_tensor * src1 = dst->src[1];  // k_tokens
+    const struct ggml_tensor * src2 = dst->src[2];  // v_tokens
+    const struct ggml_tensor * src3 = dst->src[3];  // g_tokens_exp
+    const struct ggml_tensor * src4 = dst->src[4];  // beta_tokens
+    const struct ggml_tensor * src5 = dst->src[5];  // state
+    // src6, src7, src8 are nullptr in recurrent version
+
+    const int64_t H_v               = (int64_t) dst->op_params[0];
+    const int64_t S_k               = (int64_t) dst->op_params[1];
+    const int64_t S_v               = (int64_t) dst->op_params[2];
+    const int64_t original_n_tokens = (int64_t) dst->op_params[3];  // Get original sequence length
+    const int64_t n_tokens          = original_n_tokens;            // Use the original sequence length
+    const int64_t n_seqs            = src0->ne[3];                  // q tensor has n_seqs in dim 3
+
+    // Add assertions to verify tensor dimensions
+    GGML_ASSERT(src0->ne[3] == n_seqs);  // q tensor
+    GGML_ASSERT(src1->ne[3] == n_seqs);  // k tensor
+    GGML_ASSERT(src2->ne[3] == n_seqs);  // v tensor
+    GGML_ASSERT(src3->ne[3] == n_seqs);  // g tensor
+    GGML_ASSERT(src4->ne[3] == n_seqs);  // beta tensor
+    GGML_ASSERT(src5->ne[3] == n_seqs);  // state tensor
+
+    float * dst_data  = (float *) dst->data;
+    // Output is first part, state is second part
+    float * output    = dst_data; // [S_v * H_v * n_tokens * n_seqs]
+    float * final_state = dst_data + (S_v * H_v * n_tokens * n_seqs);  // [S_v * S_v * H_v * n_seqs]
+
+    const int ith = params->ith;
+    // const int nth = params->nth;
+
+    // Clear output and new state section
+    if (ith == 0) {
+        memset(output, 0, ((S_v * H_v * n_tokens * n_seqs) + (S_v * S_v * H_v * n_seqs)) * sizeof(float));
+    } else {
+        return; // only calculate on one thread
+    }
+
+    float * state_data = (float *) src5->data; // state is now src5
+
+    GGML_ASSERT(ggml_is_contiguous(src0));
+    GGML_ASSERT(ggml_is_contiguous(src1));
+    GGML_ASSERT(ggml_is_contiguous(src2));
+    GGML_ASSERT(ggml_is_contiguous(src3));
+    GGML_ASSERT(ggml_is_contiguous(src4));
+    GGML_ASSERT(ggml_is_contiguous(src5));
+
+    const auto state_ptr = [state_data, src5] (int64_t seq, int64_t head, int64_t i, int64_t j) {
+        return state_data + (j * src5->nb[0] / sizeof(float)) + (i * src5->nb[1] / sizeof(float)) + 
+            (head * src5->nb[2] / sizeof(float)) + (seq * src5->nb[3] / sizeof(float));
+    };
+
+    // Process each token sequentially across all sequences and heads (recurrent processing)
+    // Following the PyTorch reference: for each token i, process all sequences and heads
+    for (int64_t token = 0; token < n_tokens; token++) {
+        const auto q_t = [token, src0] (int64_t seq, int64_t head, int64_t i) { return ggml_get_f32_nd(src0, token, i, head, seq); };
+        const auto k_t = [token, src1] (int64_t seq, int64_t head, int64_t i) { return ggml_get_f32_nd(src1, token, i, head, seq); };
+        const auto v_t = [token, src2] (int64_t seq, int64_t head, int64_t i) { return ggml_get_f32_nd(src2, token, i, head, seq); };
+        const auto g_exp_t = [token, src3] (int64_t seq, int64_t head) { return ggml_get_f32_nd(src3, token, 0, head, seq); };
+        const auto beta_t = [token, src4] (int64_t seq, int64_t head) { return ggml_get_f32_nd(src4, token, 0, head, seq); };
+        
+        float * delta = (float *)malloc(S_v * H_v * n_seqs * sizeof(float));
+        float * kv_mem = (float *)malloc(S_v * H_v * n_seqs * sizeof(float));
+        float * attn_out_t = (float *)malloc(S_v * H_v * n_seqs * sizeof(float));
+        
+        // Create temporary arrays for processing all sequences and heads at once
+        float * temp_state = (float *) malloc(S_v * S_v * H_v * n_seqs * sizeof(float));
+        
+        // Initialize temp_state with current state values for all sequences and heads
+        for (int64_t seq = 0; seq < n_seqs; seq++) {
+            for (int64_t head = 0; head < H_v; head++) {
+                for (int64_t i = 0; i < S_v; i++) {
+                    for (int64_t j = 0; j < S_v; j++) {
+                        int64_t idx = seq * (S_v * S_v * H_v) + head * (S_v * S_v) + i * S_v + j;
+                        temp_state[idx] = *(state_ptr(seq, head, i, j));
+                    }
+                }
+            }
+        }
+        print_debug_info(temp_state, n_seqs * H_v * S_v * S_v, "temp_state_copy", token);
+
+        // 1. last_recurrent_state = last_recurrent_state * g_t (for all seqs and heads)
+        for (int64_t seq = 0; seq < n_seqs; seq++) {
+            for (int64_t head = 0; head < H_v; head++) {
+                float g_exp = g_exp_t(seq, head);
+                for (int64_t i = 0; i < S_v; i++) {
+                    for (int64_t j = 0; j < S_v; j++) {
+                        int64_t idx = seq * (S_v * S_v * H_v) + head * (S_v * S_v) + i * S_v + j;
+                        temp_state[idx] *= g_exp;
+                    }
+                }
+            }
+        }
+        print_debug_info(temp_state, n_seqs * H_v * S_v * S_v, "temp_state_times_g_t", token);
+        
+        // 2. kv_mem = (last_recurrent_state * k_t.unsqueeze(-1)).sum(dim=-2) (for all seqs and heads)
+        for (int64_t seq = 0; seq < n_seqs; seq++) {
+            for (int64_t head = 0; head < H_v; head++) {
+                for (int64_t j = 0; j < S_v; j++) {
+                    kv_mem[seq * H_v * S_v + head * S_v + j] = 0.0f;
+                    for (int64_t i = 0; i < S_v; i++) {
+                        int64_t state_idx = seq * (S_v * S_v * H_v) + head * (S_v * S_v) + i * S_v + j;
+                        // This implements: (last_recurrent_state * k_t.unsqueeze(-1)).sum(dim=-2)
+                        kv_mem[seq * H_v * S_v + head * S_v + j] += temp_state[state_idx] * k_t(seq, head, i);
+                    }
+                }
+            }
+        }
+        print_debug_info(kv_mem, n_seqs * H_v * S_v, "kv_mem", token);
+        
+        // 3. delta = (v_t - kv_mem) * beta_t (for all seqs and heads)
+        for (int64_t seq = 0; seq < n_seqs; seq++) {
+            for (int64_t head = 0; head < H_v; head++) {
+                float beta_val = beta_t(seq, head);
+                for (int64_t j = 0; j < S_v; j++) {
+                    delta[seq * H_v * S_v + head * S_v + j] =
+                        (v_t(seq, head, j) - kv_mem[seq * H_v * S_v + head * S_v + j]) * beta_val;
+                }
+            }
+        }
+        print_debug_info(delta, n_seqs * H_v * S_v, "delta", token);
+        
+        // 4. last_recurrent_state = last_recurrent_state + k_t.unsqueeze(-1) * delta.unsqueeze(-2) (for all seqs and heads)
+        for (int64_t seq = 0; seq < n_seqs; seq++) {
+            for (int64_t head = 0; head < H_v; head++) {
+                for (int64_t i = 0; i < S_v; i++) {
+                    for (int64_t j = 0; j < S_v; j++) {
+                        int64_t state_idx = seq * (S_v * S_v * H_v) + head * (S_v * S_v) + i * S_v + j;
+                        // k_t[i] * delta[j] (where delta is treated as column vector)
+                        temp_state[state_idx] += k_t(seq, head, i) * delta[seq * H_v * S_v + head * S_v + j];
+                    }
+                }
+            }
+        }
+        print_debug_info(temp_state, n_seqs * H_v * S_v * S_v, "temp_state", token);
+        
+        // 5. core_attn_out[:, :, i] = (last_recurrent_state * q_t.unsqueeze(-1)).sum(dim=-2) (for all seqs and heads)
+        for (int64_t seq = 0; seq < n_seqs; seq++) {
+            for (int64_t head = 0; head < H_v; head++) {
+                for (int64_t j = 0; j < S_v; j++) {
+                    attn_out_t[seq * H_v * S_v + head * S_v + j] = 0.0f;
+                    for (int64_t i = 0; i < S_v; i++) {
+                        int64_t state_idx = seq * (S_v * S_v * H_v) + head * (S_v * S_v) + i * S_v + j;
+                        attn_out_t[seq * H_v * S_v + head * S_v + j] += temp_state[state_idx] * q_t(seq, head, i);
+                    }
+                }
+            }
+        }
+        print_debug_info(attn_out_t, n_seqs * S_v * H_v, "attn_out_t", token);
+        
+        // Store the output for this token (for all seqs and heads)
+        for (int64_t seq = 0; seq < n_seqs; seq++) {
+            for (int64_t head = 0; head < H_v; head++) {
+                for (int64_t d = 0; d < S_v; d++) {
+                    int64_t output_idx = d + head * S_v + token * (S_v * H_v) + seq * (S_v * H_v * n_tokens);
+                    output[output_idx] = attn_out_t[seq * H_v * S_v + head * S_v + d];
+                }
+            }
+        }
+        
+        // Update the working state for next token iteration (in the state tensor for all seqs and heads)
+        for (int64_t seq = 0; seq < n_seqs; seq++) {
+            for (int64_t head = 0; head < H_v; head++) {
+                for (int64_t i = 0; i < S_v; i++) {
+                    for (int64_t j = 0; j < S_v; j++) {
+                        int64_t state_idx = seq * (S_v * S_v * H_v) + head * (S_v * S_v) + i * S_v + j;
+                        *(state_ptr(seq, head, i, j)) = temp_state[state_idx];
+                        
+                        // Store the final state for this head and sequence (for output)
+                        int64_t final_state_idx = i + j * S_v + head * (S_v * S_v) + seq * (S_v * S_v * H_v);
+                        final_state[final_state_idx] = temp_state[state_idx];
+                    }
+                }
+            }
+        }
+        
+        free(temp_state);
+        free(delta);
+        free(kv_mem);
+        free(attn_out_t);
+    }
+}
+
 // ggml_compute_forward_rwkv_wkv7
 static void ggml_compute_forward_rwkv_wkv7_f32(
         const ggml_compute_params * params,
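
For reference, stripped of the sequence/head bookkeeping, the per-token recurrence in ggml_compute_forward_delta_net_recurrent_f32 above reduces to the following update on an S x S row-major state. This is an illustrative sketch, not part of the commit.

    // One DeltaNet recurrence step for a single sequence/head.
    // state: S x S (row-major), q/k/v: length S, out: length S.
    static void delta_net_step(float * state, const float * q, const float * k,
                               const float * v, float g_exp, float beta,
                               float * out, int S) {
        for (int i = 0; i < S; i++) {
            for (int j = 0; j < S; j++) {
                state[i*S + j] *= g_exp;                 // 1. decay: state *= exp(g)
            }
        }
        for (int j = 0; j < S; j++) {
            float kv = 0.0f;
            for (int i = 0; i < S; i++) {
                kv += state[i*S + j] * k[i];             // 2. kv_mem = (state * k).sum(-2)
            }
            const float delta = (v[j] - kv) * beta;      // 3. delta = (v - kv_mem) * beta
            for (int i = 0; i < S; i++) {
                state[i*S + j] += k[i] * delta;          // 4. outer-product update: state += k * delta
            }
        }
        for (int j = 0; j < S; j++) {
            out[j] = 0.0f;
            for (int i = 0; i < S; i++) {
                out[j] += state[i*S + j] * q[i];         // 5. readout: out = (state * q).sum(-2)
            }
        }
    }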

+ 1 - 0
ggml/src/ggml-cpu/ops.h

@@ -103,6 +103,7 @@ void ggml_compute_forward_add_rel_pos(const struct ggml_compute_params * params,
 void ggml_compute_forward_rwkv_wkv6(const struct ggml_compute_params * params, struct ggml_tensor * dst);
 void ggml_compute_forward_rwkv_wkv7(const struct ggml_compute_params * params, struct ggml_tensor * dst);
 void ggml_compute_forward_delta_net_f32(const struct ggml_compute_params * params, struct ggml_tensor * dst);
+void ggml_compute_forward_delta_net_recurrent_f32(const struct ggml_compute_params * params, struct ggml_tensor * dst);
 void ggml_compute_forward_gla(const struct ggml_compute_params * params, struct ggml_tensor * dst);
 void ggml_compute_forward_map_custom1(const struct ggml_compute_params * params, struct ggml_tensor * dst);
 void ggml_compute_forward_map_custom2(const struct ggml_compute_params * params, struct ggml_tensor * dst);

+ 16 - 0
ggml/src/ggml-cpu/unary-ops.cpp

@@ -64,6 +64,14 @@ static inline float op_log(float x) {
     return logf(x);
 }
 
+static inline float op_expm1(float x) {
+    return expf(x) - 1.0f;
+}
+
+static inline float op_softplus(float x) {
+    return (x > 20.0f) ? x : logf(1.0f + expf(x));
+}
+
 template <float (*op)(float), typename src0_t, typename dst_t>
 static inline void vec_unary_op(int64_t n, dst_t * y, const src0_t * x) {
     constexpr auto src0_to_f32 = type_conversion_table<src0_t>::to_f32;
@@ -184,3 +192,11 @@ void ggml_compute_forward_cos(const ggml_compute_params * params, ggml_tensor *
 void ggml_compute_forward_log(const ggml_compute_params * params, ggml_tensor * dst) {
     unary_op<op_log>(params, dst);
 }
+
+void ggml_compute_forward_expm1(const ggml_compute_params * params, ggml_tensor * dst) {
+    unary_op<op_expm1>(params, dst);
+}
+
+void ggml_compute_forward_softplus(const ggml_compute_params * params, ggml_tensor * dst) {
+    unary_op<op_softplus>(params, dst);
+}
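
The same results could also be obtained with the dedicated libm routines, which are more accurate for arguments near zero; drop-in variants below are an alternative sketch, not part of this commit, and the CUDA kernels in unary.cu could be adjusted the same way.

    static inline float op_expm1(float x) {
        return expm1f(x);                          // avoids cancellation of expf(x) - 1.0f near x = 0
    }

    static inline float op_softplus(float x) {
        return (x > 20.0f) ? x : log1pf(expf(x));  // branch still guards expf() overflow for large x
    }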

+ 2 - 0
ggml/src/ggml-cpu/unary-ops.h

@@ -22,6 +22,8 @@ void ggml_compute_forward_sqrt(const struct ggml_compute_params * params, struct
 void ggml_compute_forward_sin(const struct ggml_compute_params * params, struct ggml_tensor * dst);
 void ggml_compute_forward_cos(const struct ggml_compute_params * params, struct ggml_tensor * dst);
 void ggml_compute_forward_log(const struct ggml_compute_params * params, struct ggml_tensor * dst);
+void ggml_compute_forward_expm1(const struct ggml_compute_params * params, struct ggml_tensor * dst);
+void ggml_compute_forward_softplus(const struct ggml_compute_params * params, struct ggml_tensor * dst);
 
 #ifdef __cplusplus
 }

+ 8 - 0
ggml/src/ggml-cuda/ggml-cuda.cu

@@ -2333,6 +2333,12 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
                 case GGML_UNARY_OP_ELU:
                     ggml_cuda_op_elu(ctx, dst);
                     break;
+                case GGML_UNARY_OP_EXPM1:
+                    ggml_cuda_op_expm1(ctx, dst);
+                    break;
+                case GGML_UNARY_OP_SOFTPLUS:
+                    ggml_cuda_op_softplus(ctx, dst);
+                    break;
                 default:
                     return false;
             }
@@ -3314,6 +3320,8 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
                 case GGML_UNARY_OP_GELU_QUICK:
                 case GGML_UNARY_OP_TANH:
                 case GGML_UNARY_OP_EXP:
+                case GGML_UNARY_OP_EXPM1:
+                case GGML_UNARY_OP_SOFTPLUS:
                 case GGML_UNARY_OP_ELU:
                     return ggml_is_contiguous(op->src[0]);
                 default:

+ 16 - 0
ggml/src/ggml-cuda/unary.cu

@@ -83,6 +83,14 @@ static __device__ __forceinline__ float op_log(float x) {
     return logf(x);
 }
 
+static __device__ __forceinline__ float op_expm1(float x) {
+    return expf(x) - 1.0f;
+}
+
+static __device__ __forceinline__ float op_softplus(float x) {
+    return (x > 20.0f) ? x : logf(1.0f + expf(x));
+}
+
 static __device__ __forceinline__ float op_elu(float x) {
     return (x > 0.f) ? x : expm1f(x);
 }
@@ -203,6 +211,14 @@ void ggml_cuda_op_log(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
 void ggml_cuda_op_elu(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
     ggml_cuda_op_unary<op_elu>(ctx, dst);
 }
+
+void ggml_cuda_op_expm1(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+    ggml_cuda_op_unary<op_expm1>(ctx, dst);
+}
+
+void ggml_cuda_op_softplus(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+    ggml_cuda_op_unary<op_softplus>(ctx, dst);
+}
 /* gated ops */
 
 template <float (*op)(float), typename T>

+ 4 - 0
ggml/src/ggml-cuda/unary.cuh

@@ -59,6 +59,10 @@ void ggml_cuda_op_cos(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
 
 void ggml_cuda_op_log(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
 
+void ggml_cuda_op_expm1(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
+
+void ggml_cuda_op_softplus(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
+
 void ggml_cuda_op_elu(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
 
 void ggml_cuda_op_reglu(ggml_backend_cuda_context & ctx, ggml_tensor * dst);

+ 65 - 12
ggml/src/ggml.c

@@ -1005,6 +1005,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
     "GATED_LINEAR_ATTN",
     "RWKV_WKV7",
     "DELTA_NET",
+    "DELTA_NET_RECURRENT",
 
     "UNARY",
 
@@ -1022,7 +1023,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
     "GLU",
 };
 
-static_assert(GGML_OP_COUNT == 93, "GGML_OP_COUNT != 93");
+static_assert(GGML_OP_COUNT == 94, "GGML_OP_COUNT != 94");
 
 static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
     "none",
@@ -1112,6 +1113,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
     "gated_linear_attn(k, v, q, gate, s)",
     "rwkv_wkv7(r, w, k, v, a, b, s)",
     "delta_net(q, k, v, g, beta, state)",
+    "delta_net_recurrent(q, k, v, g, beta, state)",
 
     "unary(x)",
 
@@ -1129,7 +1131,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
     "glu(x)",
 };
 
-static_assert(GGML_OP_COUNT == 93, "GGML_OP_COUNT != 93");
+static_assert(GGML_OP_COUNT == 94, "GGML_OP_COUNT != 94");
 
 static_assert(GGML_OP_POOL_COUNT == 2, "GGML_OP_POOL_COUNT != 2");
 
@@ -1148,10 +1150,12 @@ static const char * GGML_UNARY_OP_NAME[GGML_UNARY_OP_COUNT] = {
     "HARDSWISH",
     "HARDSIGMOID",
     "EXP",
+    "EXPM1",
+    "SOFTPLUS",
     "GELU_ERF",
 };
 
-static_assert(GGML_UNARY_OP_COUNT == 15, "GGML_UNARY_OP_COUNT != 15");
+static_assert(GGML_UNARY_OP_COUNT == 17, "GGML_UNARY_OP_COUNT != 17");
 
 
 static const char * GGML_GLU_OP_NAME[GGML_GLU_OP_COUNT] = {
@@ -2260,6 +2264,30 @@ struct ggml_tensor * ggml_log_inplace(
     return ggml_log_impl(ctx, a, true);
 }
 
+struct ggml_tensor * ggml_expm1(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a) {
+    return ggml_unary(ctx, a, GGML_UNARY_OP_EXPM1);
+}
+
+struct ggml_tensor * ggml_expm1_inplace(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a) {
+    return ggml_unary_inplace(ctx, a, GGML_UNARY_OP_EXPM1);
+}
+
+struct ggml_tensor * ggml_softplus(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a) {
+    return ggml_unary(ctx, a, GGML_UNARY_OP_SOFTPLUS);
+}
+
+struct ggml_tensor * ggml_softplus_inplace(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a) {
+    return ggml_unary_inplace(ctx, a, GGML_UNARY_OP_SOFTPLUS);
+}
+
 // ggml_sin
 
 static struct ggml_tensor * ggml_sin_impl(
@@ -6402,16 +6430,41 @@ static void ggml_compute_backward(
                         ggml_add_or_set(ctx, cgraph, isrc0, ggml_silu_back(ctx, grad, src0));
                     }
                 } break;
-                case GGML_UNARY_OP_EXP: {
-                    if (src0_needs_grads) {
-                        ggml_add_or_set(ctx, cgraph, isrc0, ggml_mul(ctx, tensor, grad));
+                case GGML_UNARY_OP_EXP:
+                    {
+                        if (src0_needs_grads) {
+                            ggml_add_or_set(ctx, cgraph, isrc0, ggml_mul(ctx, tensor, grad));
+                        }
                     }
-                } break;
-                default: {
-                    fprintf(stderr, "%s: unsupported unary op for backward pass: %s\n",
-                        __func__, ggml_unary_op_name(ggml_get_unary_op(tensor)));
-                    GGML_ABORT("fatal error");
-                } //break;
+                    break;
+                case GGML_UNARY_OP_EXPM1:
+                    {
+                        if (src0_needs_grads) {
+                            ggml_add_or_set(ctx, cgraph, isrc0, ggml_mul(ctx, grad, ggml_exp(ctx, src0)));
+                        }
+                    }
+                    break;
+                case GGML_UNARY_OP_SOFTPLUS:
+                    {
+                        if (src0_needs_grads) {
+                            // gradient of softplus: sigmoid(x) = 1 / (1 + exp(-x))
+                            struct ggml_tensor * neg_src0 = ggml_neg(ctx, src0);
+                            struct ggml_tensor * exp_neg  = ggml_exp(ctx, neg_src0);
+                            struct ggml_tensor * ones =
+                                ggml_exp(ctx, ggml_new_tensor_4d(ctx, src0->type, src0->ne[0], src0->ne[1], src0->ne[2],
+                                                                 src0->ne[3]));
+                            struct ggml_tensor * one_plus_exp = ggml_add(ctx, ones, exp_neg);
+                            struct ggml_tensor * sigmoid      = ggml_div(ctx, ones, one_plus_exp);
+                            ggml_add_or_set(ctx, cgraph, isrc0, ggml_mul(ctx, grad, sigmoid));
+                        }
+                    }
+                    break;
+                default:
+                    {
+                        fprintf(stderr, "%s: unsupported unary op for backward pass: %s\n", __func__,
+                                ggml_unary_op_name(ggml_get_unary_op(tensor)));
+                        GGML_ABORT("fatal error");
+                    }  //break;
             }
         } break;
         case GGML_OP_CROSS_ENTROPY_LOSS: {
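
Since d/dx softplus(x) = sigmoid(x) and ggml already exposes ggml_sigmoid, the same gradient could be expressed without materializing a tensor of ones; an alternative sketch of the GGML_UNARY_OP_SOFTPLUS case above (not what the commit does):

    case GGML_UNARY_OP_SOFTPLUS:
        {
            if (src0_needs_grads) {
                ggml_add_or_set(ctx, cgraph, isrc0, ggml_mul(ctx, grad, ggml_sigmoid(ctx, src0)));
            }
        }
        break;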

+ 27 - 80
src/models/llm_build_qwen3next.cpp

@@ -361,8 +361,8 @@ struct ggml_tensor * llm_build_qwen3next::delta_net(
     cb(attn, "attn_in", il);
 
     // We'll be returning the result as a 1D tensor due to the dimensions mismatch of the state and output tensors
-    const int64_t ne[1] = { (S_v * H_v * n_tokens * n_seqs ) + (S_v * S_v * H_v * n_seqs) };
-    struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 1, ne);
+    const int64_t total_dims = (S_v * H_v * n_tokens * n_seqs) + (S_v * S_v * H_v * n_seqs);
+    struct ggml_tensor * result = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, total_dims);
 
     ggml_set_op_params_i32(result, 0, H_v);
     ggml_set_op_params_i32(result, 1, S_k);
@@ -384,7 +384,6 @@ struct ggml_tensor * llm_build_qwen3next::delta_net(
 }
 
 // delta_net_recurrent
-// Recurrent version of delta_net for sequence_length = 1
 struct ggml_tensor * llm_build_qwen3next::delta_net_recurrent(
         struct ggml_context * ctx,
         struct ggml_tensor  * q,
@@ -467,79 +466,33 @@ struct ggml_tensor * llm_build_qwen3next::delta_net_recurrent(
     state = ggml_cont_4d(ctx, state, S_v, S_v, H_k, n_seqs);
     ggml_tensor * g_tokens_exp = ggml_exp(ctx, g_tokens);
 
-    ggml_tensor * final_output = nullptr;
-    ggml_tensor * q_t, * k_t, * v_t, * g_t_exp, * beta_t;
-    for (int i = 0; i < n_tokens; i++) { // this part is per token
-        if (n_tokens == 1) { // don't do unnecessary reshapes / views
-            q_t = q_tokens;
-            k_t = k_tokens;
-            v_t = v_tokens;
-            g_t_exp = g_tokens_exp;
-            beta_t = beta_tokens;
-        } else {
-            q_t = ggml_view_4d(ctx, q_tokens, 1, S_k, H_k, n_seqs, q_tokens->nb[1], q_tokens->nb[2], q_tokens->nb[3], i * ggml_element_size(q_tokens));
-            k_t = ggml_view_4d(ctx, k_tokens, 1, S_k, H_k, n_seqs, k_tokens->nb[1], k_tokens->nb[2], k_tokens->nb[3], i * ggml_element_size(k_tokens));
-            v_t = ggml_view_4d(ctx, v_tokens, 1, S_v, H_k, n_seqs, v_tokens->nb[1], v_tokens->nb[2], v_tokens->nb[3], i * ggml_element_size(v_tokens));
-            g_t_exp = ggml_view_4d(ctx, g_tokens_exp, 1, 1, H_k, n_seqs, g_tokens_exp->nb[1], g_tokens_exp->nb[2], g_tokens_exp->nb[3], i * ggml_element_size(g_tokens_exp));
-            beta_t = ggml_view_4d(ctx, beta_tokens, 1, 1, H_k, n_seqs, beta_tokens->nb[1], beta_tokens->nb[2], beta_tokens->nb[3], i * ggml_element_size(beta_tokens));
-        }
-
-        // Apply gate to state: state = state * exp(g)
-        ggml_tensor * gated_state = ggml_mul(ctx, state, g_t_exp);
-        cb(gated_state, "gated_state", il);
-
-        // Compute kv_memory from state and key
-        // kv_mem = (state * k.unsqueeze(-1)).sum(dim=-2)
-        
-        // Reshape gated_state from [S_v, S_v*H_v, 1, n_seqs] to [S_v, S_v, H_v, n_seqs]
-        // to make it compatible with k_expanded for element-wise multiplication
-        ggml_tensor * gated_state_reshaped = ggml_reshape_4d(ctx, gated_state, S_v, S_v, H_v, n_seqs);
-        cb(gated_state_reshaped, "gated_state_reshaped", il);
-        
-        ggml_tensor * state_k_product = ggml_mul(ctx, gated_state_reshaped, k_t);
-        cb(state_k_product, "state_k_product", il);
-
-        ggml_tensor * kv_memory = ggml_sum_rows(ctx, ggml_cont(ctx, ggml_transpose(ctx, state_k_product)));
-        cb(kv_memory, "kv_memory", il);
-
-        // Compute delta = (v - kv_memory) * beta
-        ggml_tensor * v_diff = ggml_sub(ctx, v_t, kv_memory);
-        ggml_tensor * delta = ggml_mul(ctx, v_diff, beta_t);
-        cb(delta, "delta", il);
-
-        // Update state = state + k * delta
-        // In the reference: last_recurrent_state = last_recurrent_state + k_t.unsqueeze(-1) * delta.unsqueeze(-2)
-        ggml_tensor * delta_t = ggml_transpose(ctx, delta);
-
-        // Will need to broadcast here since GGML doesn't support auto-double-broadcasting on mul
-        ggml_tensor * delta_t_broadcast = ggml_repeat_4d(ctx, delta_t, S_v, S_v, H_v, n_seqs);
-        ggml_tensor * k_t_broadcast  = ggml_repeat_4d(ctx, k_t, S_v, S_v, H_v, n_seqs);
-        ggml_tensor * k_delta_product = ggml_mul(ctx, k_t_broadcast, delta_t_broadcast);
-        cb(k_delta_product, "k_delta", il);
+    // Create result tensor with the same dimensions as delta_net
+    const int64_t total_dims = (S_v * H_v * n_tokens * n_seqs) + (S_v * S_v * H_v * n_seqs);
+    ggml_tensor * result = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, total_dims);
 
-        state = ggml_add(ctx, gated_state_reshaped, k_delta_product);
-        cb(state, "updated_state", il);
+    cb(q_tokens, "q_tokens", il);
+    cb(k_tokens, "k_tokens", il);
+    cb(v_tokens, "v_tokens", il);
+    cb(g_tokens, "g_tokens", il);
+    cb(beta_tokens, "beta_tokens", il);
+    cb(g_tokens_exp, "g_tokens_exp", il);
+    cb(state, "state_pre", il);
 
-        ggml_tensor * state_q_product = ggml_mul(ctx, state, q_t);
-        cb(state_q_product, "state_q_product", il);
-        
-        ggml_tensor * output = ggml_sum_rows(ctx, ggml_cont(ctx, ggml_transpose(ctx, state_q_product)));
-        cb(output, "output", il);
+    // Set operation parameters
+    ggml_set_op_params_i32(result, 0, H_v);
+    ggml_set_op_params_i32(result, 1, S_k);
+    ggml_set_op_params_i32(result, 2, S_v);
+    ggml_set_op_params_i32(result, 3, n_tokens); // Pass original n_tokens
 
-        if (final_output == nullptr) {
-            final_output = output;
-        } else {
-            final_output = ggml_concat(ctx, final_output, output, 0);
-        }
-    }
+    // Set operation and source tensors
+    result->op     = GGML_OP_DELTA_NET_RECURRENT;
+    result->src[0] = q_tokens;
+    result->src[1] = k_tokens;
+    result->src[2] = v_tokens;
+    result->src[3] = g_tokens_exp;
+    result->src[4] = beta_tokens;
+    result->src[5] = state;
     
-    // Concatenate output and updated_state into a single tensor
-    // First, flatten both tensors to 1D
-    ggml_tensor * output_1d = ggml_cont_1d(ctx, final_output, ggml_nelements(final_output));
-    ggml_tensor * updated_state_1d = ggml_cont_1d(ctx, state, ggml_nelements(state));
-    
-    // Concatenate them: [output, updated_state]
-    ggml_tensor * result = ggml_concat(ctx, output_1d, updated_state_1d, 0);
     return result;
 }
 
@@ -604,7 +557,8 @@ ggml_tensor * llm_build_qwen3next::build_qwen3next_linear_attn_layer(llm_graph_i
 
     GGML_ASSERT(ggml_nelements(beta) + ggml_nelements(alpha) == ggml_nelements(mixed_ba));
 
-    ggml_tensor * alpha_softplus = softplus(alpha, model.layers[il].ssm_dt);
+    ggml_tensor * alpha_biased = ggml_add(ctx0, alpha, model.layers[il].ssm_dt);
+    ggml_tensor * alpha_softplus = ggml_softplus(ctx0, alpha_biased);
     cb(alpha_softplus, "a_softplus", il);
     ggml_tensor * gate = ggml_mul(ctx0, alpha_softplus, model.layers[il].ssm_a);  // -A_log.exp() * softplus
     cb(gate, "gate", il);
@@ -870,10 +824,3 @@ ggml_tensor * llm_build_qwen3next::build_layer_ffn(ggml_tensor * cur, const llam
     return cur;
 }
 
-ggml_tensor * llm_build_qwen3next::softplus(ggml_tensor * alpha, ggml_tensor * dt_bias) {
-    ggml_tensor * alpha_biased   = ggml_add(ctx0, alpha, dt_bias);                // a + dt_bias
-    ggml_tensor * alpha_exp      = ggml_exp(ctx0, alpha_biased);                  // exp(a + dt_bias)
-    ggml_tensor * one_plus_exp   = ggml_scale_bias(ctx0, alpha_exp, 1.0f, 1.0f);  // 1 + exp(a + dt_bias)
-    ggml_tensor * alpha_softplus = ggml_log(ctx0, one_plus_exp);                  // log(1 + exp(...))
-    return alpha_softplus;
-}
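
Because delta_net_recurrent now packs the attention output and the updated recurrent state into a single 1D result tensor, the caller has to slice the result apart again. A hypothetical consumer-side sketch follows; the splitting code is not shown in this diff and the variable names are illustrative.

    const int64_t n_out   = S_v * H_v * n_tokens * n_seqs;   // attention output elements
    const int64_t n_state = S_v * S_v * H_v * n_seqs;        // recurrent state elements
    ggml_tensor * attn_out  = ggml_view_1d(ctx, result, n_out, 0);
    ggml_tensor * new_state = ggml_view_1d(ctx, result, n_state,
                                           n_out * ggml_element_size(result));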

+ 0 - 2
src/models/llm_build_qwen3next.h

@@ -51,8 +51,6 @@ private:
 
     ggml_tensor * build_layer_ffn(ggml_tensor * cur, const llama_model & model, const int il);
 
-    ggml_tensor * softplus(ggml_tensor * alpha, ggml_tensor * dt_bias);
-
     ggml_tensor * build_q3n_norm(struct ggml_tensor * input, struct ggml_tensor * weights, int layer);
     ggml_tensor * build_q3n_gated_norm(struct ggml_tensor * input, struct ggml_tensor * weights, struct ggml_tensor * gate, int layer);
 

+ 152 - 0
tests/test-backend-ops.cpp

@@ -3610,6 +3610,150 @@ struct test_cos : public test_case {
     }
 };
 
+// GGML_OP_EXPM1
+struct test_expm1 : public test_case {
+    const ggml_type type;
+    const std::array<int64_t, 4> ne;
+
+    std::string vars() override {
+        return VARS_TO_STR2(type, ne);
+    }
+
+    test_expm1(ggml_type type = GGML_TYPE_F32,
+            std::array<int64_t, 4> ne = {10, 3, 3, 2})
+        : type(type), ne(ne) {}
+
+    ggml_tensor * build_graph(ggml_context * ctx) override {
+        ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne.data());
+        ggml_set_param(a);
+        ggml_set_name(a, "a");
+
+        ggml_tensor * out = ggml_expm1(ctx, a);
+        ggml_set_name(out, "out");
+
+        return out;
+    }
+
+    void initialize_tensors(ggml_context * ctx) override {
+        for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) {
+            // Use small values to avoid overflow in expm1
+            init_tensor_uniform(t, -2.0f, 2.0f);
+        }
+    }
+
+    bool grad_precise() override {
+        return true;
+    }
+};
+
+// GGML_OP_SOFTPLUS
+struct test_softplus : public test_case {
+    const ggml_type type;
+    const std::array<int64_t, 4> ne;
+
+    std::string vars() override {
+        return VARS_TO_STR2(type, ne);
+    }
+
+    test_softplus(ggml_type type = GGML_TYPE_F32,
+            std::array<int64_t, 4> ne = {10, 3, 3, 2})
+        : type(type), ne(ne) {}
+
+    ggml_tensor * build_graph(ggml_context * ctx) override {
+        ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne.data());
+        ggml_set_param(a);
+        ggml_set_name(a, "a");
+
+        ggml_tensor * out = ggml_softplus(ctx, a);
+        ggml_set_name(out, "out");
+
+        return out;
+    }
+
+    void initialize_tensors(ggml_context * ctx) override {
+        for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) {
+            // Use values around the threshold (20) to test both branches of softplus
+            init_tensor_uniform(t, -25.0f, 25.0f);
+        }
+    }
+
+    bool grad_precise() override {
+        return true;
+    }
+};
+
+// GGML_OP_EXPM1_INPLACE
+struct test_expm1_inplace : public test_case {
+    const ggml_type type;
+    const std::array<int64_t, 4> ne;
+
+    std::string vars() override {
+        return VARS_TO_STR2(type, ne);
+    }
+
+    test_expm1_inplace(ggml_type type = GGML_TYPE_F32,
+            std::array<int64_t, 4> ne = {10, 3, 3, 2})
+        : type(type), ne(ne) {}
+
+    ggml_tensor * build_graph(ggml_context * ctx) override {
+        ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne.data());
+        ggml_set_param(a);
+        ggml_set_name(a, "a");
+
+        ggml_tensor * out = ggml_expm1_inplace(ctx, a);
+        ggml_set_name(out, "out");
+
+        return out;
+    }
+
+    void initialize_tensors(ggml_context * ctx) override {
+        for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) {
+            // Use small values to avoid overflow in expm1
+            init_tensor_uniform(t, -2.0f, 2.0f);
+        }
+    }
+
+    bool grad_precise() override {
+        return true;
+    }
+};
+
+// GGML_OP_SOFTPLUS_INPLACE
+struct test_softplus_inplace : public test_case {
+    const ggml_type type;
+    const std::array<int64_t, 4> ne;
+
+    std::string vars() override {
+        return VARS_TO_STR2(type, ne);
+    }
+
+    test_softplus_inplace(ggml_type type = GGML_TYPE_F32,
+            std::array<int64_t, 4> ne = {10, 3, 3, 2})
+        : type(type), ne(ne) {}
+
+    ggml_tensor * build_graph(ggml_context * ctx) override {
+        ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne.data());
+        ggml_set_param(a);
+        ggml_set_name(a, "a");
+
+        ggml_tensor * out = ggml_softplus_inplace(ctx, a);
+        ggml_set_name(out, "out");
+
+        return out;
+    }
+
+    void initialize_tensors(ggml_context * ctx) override {
+        for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) {
+            // Use values around the threshold (20) to test both branches of softplus
+            init_tensor_uniform(t, -25.0f, 25.0f);
+        }
+    }
+
+    bool grad_precise() override {
+        return true;
+    }
+};
+
 // GGML_OP_CLAMP
 struct test_clamp : public test_case {
     const ggml_type type;
@@ -6332,6 +6476,10 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
         test_cases.emplace_back(new test_sqr       (type));
         test_cases.emplace_back(new test_sqrt      (type));
         test_cases.emplace_back(new test_log       (type));
+        test_cases.emplace_back(new test_expm1     (type));
+        test_cases.emplace_back(new test_softplus  (type));
+        test_cases.emplace_back(new test_expm1_inplace     (type));
+        test_cases.emplace_back(new test_softplus_inplace  (type));
         test_cases.emplace_back(new test_sin       (type));
         test_cases.emplace_back(new test_cos       (type));
         test_cases.emplace_back(new test_clamp     (type));
@@ -6339,6 +6487,10 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
         test_cases.emplace_back(new test_sqr       (type, {7, 1, 5, 3}));
         test_cases.emplace_back(new test_sqrt      (type, {7, 1, 5, 3}));
         test_cases.emplace_back(new test_log       (type, {7, 1, 5, 3}));
+        test_cases.emplace_back(new test_expm1     (type, {7, 1, 5, 3}));
+        test_cases.emplace_back(new test_softplus  (type, {7, 1, 5, 3}));
+        test_cases.emplace_back(new test_expm1_inplace     (type, {7, 1, 5, 3}));
+        test_cases.emplace_back(new test_softplus_inplace  (type, {7, 1, 5, 3}));
         test_cases.emplace_back(new test_sin       (type, {7, 1, 5, 3}));
         test_cases.emplace_back(new test_cos       (type, {7, 1, 5, 3}));
         test_cases.emplace_back(new test_clamp     (type, {7, 1, 5, 3}));