
Add CUDA kernels for TRI, CUMSUM, and DELTA_NET operations

- Add cumsum.cu/cuh: CUDA implementation for cumulative sum operation
- Add delta-net.cu/cuh: CUDA implementation for delta network operations (recurrent and chunked)
- Add tri.cu/cuh: CUDA implementation for triangular matrix operations
- Update ggml-cuda.cu to register new operations
- Update CPU ops.cpp with corresponding CPU implementations
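
As a sanity check, the new kernels would typically be compared against the CPU backend with `tests/test-backend-ops` (assuming this branch adds matching test cases for the CUMSUM, TRI and DELTA_NET ops), e.g. something like `./build/bin/test-backend-ops test -b CUDA0 -o CUMSUM`.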
cturan 2 months ago
parent
commit
2e9f9bc889

+ 143 - 0
ggml/src/ggml-cuda/cumsum.cu

@@ -0,0 +1,143 @@
+#include "cumsum.cuh"
+
+// Warp-level inclusive scan (cumulative sum)
+template<int WARP_SZ>
+__device__ inline float warp_cumsum(float val, int lane_id) {
+    #pragma unroll
+    for (int offset = 1; offset < WARP_SZ; offset *= 2) {
+        float n = __shfl_up_sync(0xffffffff, val, offset);
+        if (lane_id >= offset) val += n;
+    }
+    return val;
+}
+
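For reference, a minimal standalone check of the warp-level scan above — a sketch, not part of the patch; it assumes it is compiled in the same translation unit as `warp_cumsum` and omits error checking:

```cuda
#include <cstdio>
#include <cuda_runtime.h>

// Launch one full warp and print its inclusive prefix sum.
__global__ void warp_cumsum_test_kernel(const float * in, float * out) {
    const int lane = threadIdx.x & 31;
    out[lane] = warp_cumsum<32>(in[lane], lane);
}

int main() {
    float h_in[32], h_out[32];
    for (int i = 0; i < 32; ++i) h_in[i] = 1.0f;   // expected output: 1, 2, ..., 32
    float *d_in, *d_out;
    cudaMalloc(&d_in,  sizeof(h_in));
    cudaMalloc(&d_out, sizeof(h_out));
    cudaMemcpy(d_in, h_in, sizeof(h_in), cudaMemcpyHostToDevice);
    warp_cumsum_test_kernel<<<1, 32>>>(d_in, d_out);
    cudaMemcpy(h_out, d_out, sizeof(h_out), cudaMemcpyDeviceToHost);
    for (int i = 0; i < 32; ++i) printf("%.0f ", h_out[i]);
    printf("\n");
    cudaFree(d_in);
    cudaFree(d_out);
    return 0;
}
```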
+// Kernel for small to medium rows (used for ne0 <= 4096, see ggml_cuda_op_cumsum below)
+// Each block processes one row
+template<int BLOCK_SIZE>
+__global__ void cumsum_f32_kernel(const float * __restrict__ x, float * __restrict__ dst, 
+                                   int64_t ne0, int64_t ne1, int64_t ne2, int64_t ne3,
+                                   int64_t nb0, int64_t nb1, int64_t nb2, int64_t nb3,
+                                   int64_t dst_nb0, int64_t dst_nb1, int64_t dst_nb2, int64_t dst_nb3) {
+    
+    const int64_t i3 = blockIdx.z;
+    const int64_t i2 = blockIdx.y;
+    const int64_t i1 = blockIdx.x;
+    
+    if (i3 >= ne3 || i2 >= ne2 || i1 >= ne1) return;
+    
+    const float * src_row = (const float *)((const char *)x + i1*nb1 + i2*nb2 + i3*nb3);
+    float * dst_row = (float *)((char *)dst + i1*dst_nb1 + i2*dst_nb2 + i3*dst_nb3);
+    
+    const int tid = threadIdx.x;
+    const int lane_id = tid & 31;
+    const int warp_id = tid / 32;
+    const int num_warps = BLOCK_SIZE / 32;
+    
+    __shared__ float warp_sums[32]; // max 32 warps per block
+    
+    // Carry across tiles is kept in a register (only thread 0's copy is used)
+    float carry_accum = 0.0f;
+    
+    // Process the row in tiles of BLOCK_SIZE elements. Every thread walks the same
+    // number of tiles so that the __syncthreads() below are reached uniformly even
+    // when ne0 is not a multiple of BLOCK_SIZE.
+    for (int64_t base = 0; base < ne0; base += BLOCK_SIZE) {
+        const int64_t i = base + tid;
+        float val = (i < ne0) ? src_row[i] : 0.0f;
+        
+        // Warp-level scan
+        float warp_sum = warp_cumsum<32>(val, lane_id);
+        
+        // The last lane of each warp publishes the warp total
+        if (lane_id == 31) {
+            warp_sums[warp_id] = warp_sum;
+        }
+        __syncthreads();
+        
+        // Thread 0 turns the warp totals into exclusive per-warp offsets
+        __shared__ float tile_carry;
+        if (tid == 0) {
+            float s = 0.0f;
+            for (int w = 0; w < num_warps; w++) {
+                float tmp = warp_sums[w];
+                warp_sums[w] = s; // warp prefix offset within this tile
+                s += tmp;         // accumulate total of this tile
+            }
+            tile_carry = carry_accum; // carry to add to this tile's results
+            carry_accum += s;         // update carry for the next tile
+        }
+        __syncthreads();
+
+        // Add warp prefix and the carry from previous tiles
+        if (i < ne0) {
+            dst_row[i] = warp_sum + warp_sums[warp_id] + tile_carry;
+        }
+    }
+}
+
+// Fallback for very large rows: sequential processing
+__global__ void cumsum_f32_sequential_kernel(const float * __restrict__ x, float * __restrict__ dst,
+                                              int64_t ne0, int64_t ne1, int64_t ne2, int64_t ne3,
+                                              int64_t nb0, int64_t nb1, int64_t nb2, int64_t nb3,
+                                              int64_t dst_nb0, int64_t dst_nb1, int64_t dst_nb2, int64_t dst_nb3) {
+    
+    const int64_t i3 = blockIdx.z;
+    const int64_t i2 = blockIdx.y;
+    const int64_t i1 = blockIdx.x;
+    
+    if (i3 >= ne3 || i2 >= ne2 || i1 >= ne1) return;
+    if (threadIdx.x != 0) return; // Only first thread in block
+    
+    const float * src_row = (const float *)((const char *)x + i1*nb1 + i2*nb2 + i3*nb3);
+    float * dst_row = (float *)((char *)dst + i1*dst_nb1 + i2*dst_nb2 + i3*dst_nb3);
+    
+    float cumsum = 0.0f;
+    for (int64_t i = 0; i < ne0; i++) {
+        cumsum += src_row[i];
+        dst_row[i] = cumsum;
+    }
+}
+
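Both kernels implement the same semantics: an inclusive prefix sum along ne0 of every row, with the outer three dimensions treated as independent rows. A hypothetical CPU reference for validating the output (contiguous f32 rows assumed):

```cuda
// Reference inclusive prefix sum over n_rows contiguous rows of length ne0.
static void cumsum_rows_reference(const float * src, float * dst, int64_t ne0, int64_t n_rows) {
    for (int64_t r = 0; r < n_rows; ++r) {
        float acc = 0.0f;
        for (int64_t i = 0; i < ne0; ++i) {
            acc += src[r*ne0 + i];
            dst[r*ne0 + i] = acc;
        }
    }
}
```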
+void ggml_cuda_op_cumsum(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+    const ggml_tensor * src0 = dst->src[0];
+    
+    GGML_ASSERT(src0->type == GGML_TYPE_F32);
+    GGML_ASSERT(dst->type == GGML_TYPE_F32);
+    
+    const float * src0_d = (const float *) src0->data;
+    float * dst_d = (float *) dst->data;
+    
+    cudaStream_t stream = ctx.stream();
+    
+    const int64_t ne0 = src0->ne[0];
+    const int64_t ne1 = src0->ne[1];
+    const int64_t ne2 = src0->ne[2];
+    const int64_t ne3 = src0->ne[3];
+    
+    const int64_t nb0 = src0->nb[0];
+    const int64_t nb1 = src0->nb[1];
+    const int64_t nb2 = src0->nb[2];
+    const int64_t nb3 = src0->nb[3];
+    
+    const int64_t dst_nb0 = dst->nb[0];
+    const int64_t dst_nb1 = dst->nb[1];
+    const int64_t dst_nb2 = dst->nb[2];
+    const int64_t dst_nb3 = dst->nb[3];
+    
+    // Launch kernel
+    dim3 grid(ne1, ne2, ne3);
+    
+    if (ne0 <= 4096) {
+        // Use parallel scan for small to medium rows
+        const int block_size = 256;
+        cumsum_f32_kernel<block_size><<<grid, block_size, 0, stream>>>(
+            src0_d, dst_d, ne0, ne1, ne2, ne3,
+            nb0, nb1, nb2, nb3,
+            dst_nb0, dst_nb1, dst_nb2, dst_nb3
+        );
+    } else {
+        // Use sequential kernel for very large rows
+        cumsum_f32_sequential_kernel<<<grid, 1, 0, stream>>>(
+            src0_d, dst_d, ne0, ne1, ne2, ne3,
+            nb0, nb1, nb2, nb3,
+            dst_nb0, dst_nb1, dst_nb2, dst_nb3
+        );
+    }
+}
+

+ 4 - 0
ggml/src/ggml-cuda/cumsum.cuh

@@ -0,0 +1,4 @@
+#include "common.cuh"
+
+void ggml_cuda_op_cumsum(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
+

+ 561 - 0
ggml/src/ggml-cuda/delta-net.cu

@@ -0,0 +1,561 @@
+#include "delta-net.cuh"
+
+// Configure a reasonable block size. We use 256 threads (16x16) for 2D tiling when needed.
+#define DELTA_NET_BLOCK_SIZE 16
+#define T 256  // Number of threads per block (x-dimension)
+
+#if !defined(LDG)
+#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 350
+#define LDG(ptr) __ldg(ptr)
+#else
+#define LDG(ptr) (*(ptr))
+#endif
+#endif
+
+#if !defined(FMA)
+#define FMA(a,b,c) fmaf((a),(b),(c))
+#endif
+
+#ifndef GGML_DELTA_NET_CHUNK
+#define GGML_DELTA_NET_CHUNK 64
+#endif
+
+// DELTA_NET_RECURRENT kernel
+// Each block processes one (sequence, head) pair
+// Token loop is sequential due to state dependency
+__global__ void delta_net_recurrent_f32_kernel(
+    const float * __restrict__ q_tokens,      // [n_tokens, S_v, H_v, n_seqs]
+    const float * __restrict__ k_tokens,      // [n_tokens, S_v, H_v, n_seqs]
+    const float * __restrict__ v_tokens,      // [n_tokens, S_v, H_v, n_seqs]
+    const float * __restrict__ g_tokens_exp,  // [n_tokens, 1, H_v, n_seqs]
+    const float * __restrict__ beta_tokens,   // [n_tokens, 1, H_v, n_seqs]
+    const float * __restrict__ state_in,      // [S_v, S_v, H_v, n_seqs]
+    float * __restrict__ output,              // [S_v, H_v, n_tokens, n_seqs]
+    float * __restrict__ state_out,           // [S_v, S_v, H_v, n_seqs]
+    int64_t S_v,
+    int64_t H_v,
+    int64_t n_tokens,
+    int64_t n_seqs) {
+    
+    const int head = blockIdx.x;
+    const int seq = blockIdx.y;
+    
+    if (head >= H_v || seq >= n_seqs) return;
+    
+    const int tid = threadIdx.x;
+
+    // Dynamic shared memory: only vectors and a couple of scalars
+    extern __shared__ float smem[];
+    float * q_vec   = smem;             // S_v
+    float * k_vec   = q_vec   + S_v;    // S_v
+    float * v_vec   = k_vec   + S_v;    // S_v
+    float * kv_mem  = v_vec   + S_v;    // S_v
+    float * delta   = kv_mem  + S_v;    // S_v
+    float * out_vec = delta   + S_v;    // S_v
+    float * scalars = out_vec + S_v;    // 2 floats: [0]=g_exp, [1]=beta
+
+    // Offset helper matching CPU layout: [seq][head][i][j]
+    // CPU: idx = seq * (S_v * S_v * H_v) + head * (S_v * S_v) + i * S_v + j
+    // Tensor layout is actually: [j, i, head, seq] based on nb strides
+    const size_t state_base = (size_t)head * (size_t)(S_v * S_v) + (size_t)seq * (size_t)(S_v * S_v * H_v);
+    auto off_state = [=](int i, int j) -> size_t {
+        return (size_t)j + (size_t)i * S_v + state_base;
+    };
+    
+    auto off_tok_vec = [=](int token, int d) -> size_t {
+        // Matching CPU: ggml_get_f32_nd(src0, token, d, head, seq)
+        // Layout: [token, d, head, seq]
+        return (size_t)token + (size_t)d * n_tokens + (size_t)head * (n_tokens * S_v) + (size_t)seq * (n_tokens * S_v * H_v);
+    };
+    
+    auto off_scalar_tok = [=](const float * base, int token) -> float {
+        // g_exp / beta: ggml_get_f32_nd(src, token, 0, head, seq)
+        // Layout: [token, 1, head, seq]
+        return base[(size_t)token + (size_t)head * n_tokens + (size_t)seq * (n_tokens * H_v)];
+    };
+
+    // Initialize state_out with state_in
+    for (int idx = tid; idx < S_v * S_v; idx += blockDim.x) {
+        int i = idx / S_v;
+        int j = idx % S_v;
+        state_out[off_state(i, j)] = state_in[off_state(i, j)];
+    }
+    __syncthreads();
+    
+    // Process each token sequentially
+    for (int token = 0; token < n_tokens; token++) {
+        // Load q, k, v for this token (handle S_v > blockDim.x)
+        for (int d = tid; d < S_v; d += blockDim.x) {
+            q_vec[d] = LDG(&q_tokens[off_tok_vec(token, d)]);
+            k_vec[d] = LDG(&k_tokens[off_tok_vec(token, d)]);
+            v_vec[d] = LDG(&v_tokens[off_tok_vec(token, d)]);
+        }
+        // Load scalars
+        if (tid == 0) {
+            scalars[0] = off_scalar_tok(g_tokens_exp, token);
+            scalars[1] = off_scalar_tok(beta_tokens, token);
+        }
+        __syncthreads();
+        float g_exp = scalars[0];
+        float beta_val = scalars[1];
+
+        // 1. state = state * g_exp (element-wise multiplication)
+        // CPU: temp_state[idx] *= g_exp;
+        for (int idx = tid; idx < S_v * S_v; idx += blockDim.x) {
+            int i = idx / S_v;
+            int j = idx % S_v;
+            state_out[off_state(i, j)] *= g_exp;
+        }
+        __syncthreads();
+        
+        // 2. kv_mem[j] = sum_i (state[i,j] * k[i])
+        // CPU: kv_mem[j] += temp_state[state_idx] * k_t(i)
+        for (int j = tid; j < S_v; j += blockDim.x) {
+            float sum = 0.0f;
+            size_t sidx = state_base + (size_t)j;
+            #pragma unroll 4
+            for (int i = 0; i < S_v; i++) {
+                sum = FMA(state_out[sidx], k_vec[i], sum);
+                sidx += (size_t)S_v;
+            }
+            kv_mem[j] = sum;
+        }
+        __syncthreads();
+        
+        // 3. delta = (v - kv_mem) * beta
+        // CPU: delta[j] = (v_t(j) - kv_mem[j]) * beta_val
+        for (int j = tid; j < S_v; j += blockDim.x) {
+            delta[j] = (v_vec[j] - kv_mem[j]) * beta_val;
+        }
+        __syncthreads();
+        
+        // 4. state[i,j] += k[i] * delta[j]  (outer product)
+        // CPU: temp_state[state_idx] += k_t(i) * delta[j]
+        for (int idx = tid; idx < S_v * S_v; idx += blockDim.x) {
+            int i = idx / S_v;
+            int j = idx % S_v;
+            size_t sidx = state_base + (size_t)j + (size_t)i * (size_t)S_v;
+            state_out[sidx] = FMA(k_vec[i], delta[j], state_out[sidx]);
+        }
+        __syncthreads();
+        
+        // 5. output[j] = sum_i (state[i,j] * q[i])
+        // CPU: attn_out_t[j] += temp_state[state_idx] * q_t(i)
+        for (int j = tid; j < S_v; j += blockDim.x) {
+            float sum = 0.0f;
+            size_t sidx = state_base + (size_t)j;
+            #pragma unroll 4
+            for (int i = 0; i < S_v; i++) {
+                sum = FMA(state_out[sidx], q_vec[i], sum);
+                sidx += (size_t)S_v;
+            }
+            out_vec[j] = sum;
+        }
+        __syncthreads();
+        
+        // Store output for this token
+        // CPU output layout: d + head * S_v + token * (S_v * H_v) + seq * (S_v * H_v * n_tokens)
+        for (int d = tid; d < S_v; d += blockDim.x) {
+            size_t output_idx = (size_t)d + (size_t)head * S_v + (size_t)token * (S_v * H_v) + (size_t)seq * (S_v * H_v * n_tokens);
+            output[output_idx] = out_vec[d];
+        }
+        __syncthreads();
+    }
+}
+
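Written out, the per-token update the kernel performs on the per-head state $S \in \mathbb{R}^{S_v \times S_v}$ (with the gate $g$ already exponentiated, i.e. `g_tokens_exp`) is:

$$
S \leftarrow g\,S, \qquad
m = S^{\top} k, \qquad
\delta = \beta\,(v - m), \qquad
S \leftarrow S + k\,\delta^{\top}, \qquad
o = S^{\top} q,
$$

applied sequentially over the tokens of each (sequence, head) pair, which is why the token loop cannot be parallelized.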
+void ggml_cuda_op_delta_net_recurrent(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+    const ggml_tensor * src0 = dst->src[0];  // q_tokens
+    const ggml_tensor * src1 = dst->src[1];  // k_tokens
+    const ggml_tensor * src2 = dst->src[2];  // v_tokens
+    const ggml_tensor * src3 = dst->src[3];  // g_tokens_exp
+    const ggml_tensor * src4 = dst->src[4];  // beta_tokens
+    const ggml_tensor * src5 = dst->src[5];  // state
+    
+    const int64_t H_v = (int64_t) dst->op_params[0];
+    const int64_t S_v = (int64_t) dst->op_params[2];
+    const int64_t n_tokens = (int64_t) dst->op_params[3];
+    const int64_t n_seqs = src0->ne[3];
+    
+    GGML_ASSERT(src0->type == GGML_TYPE_F32);
+    GGML_ASSERT(dst->type == GGML_TYPE_F32);
+    
+    // Verify tensor dimensions match CPU expectations
+    GGML_ASSERT(src0->ne[3] == n_seqs);  // q tensor
+    GGML_ASSERT(src1->ne[3] == n_seqs);  // k tensor
+    GGML_ASSERT(src2->ne[3] == n_seqs);  // v tensor
+    GGML_ASSERT(src3->ne[3] == n_seqs);  // g tensor
+    GGML_ASSERT(src4->ne[3] == n_seqs);  // beta tensor
+    GGML_ASSERT(src5->ne[3] == n_seqs);  // state tensor
+    
+    const float * q_d = (const float *) src0->data;
+    const float * k_d = (const float *) src1->data;
+    const float * v_d = (const float *) src2->data;
+    const float * g_exp_d = (const float *) src3->data;
+    const float * beta_d = (const float *) src4->data;
+    const float * state_in_d = (const float *) src5->data;
+    
+    float * dst_d = (float *) dst->data;
+    float * output_d = dst_d;
+    float * state_out_d = dst_d + (S_v * H_v * n_tokens * n_seqs);
+    
+    cudaStream_t stream = ctx.stream();
+    
+    // Launch config
+    dim3 grid(H_v, n_seqs);
+    int block_x = 256;
+    if (S_v < 256) block_x = (S_v >= 128 ? 128 : (S_v >= 64 ? 64 : (S_v >= 32 ? 32 : 16)));
+    dim3 block(block_x, 1, 1);
+    
+    // Shared memory: 6 vectors (S_v each) + 2 scalars
+    size_t smem_size = (6 * (size_t)S_v + 2) * sizeof(float);
+    
+    delta_net_recurrent_f32_kernel<<<grid, block, smem_size, stream>>>(
+        q_d, k_d, v_d, g_exp_d, beta_d, state_in_d,
+        output_d, state_out_d,
+        S_v, H_v, n_tokens, n_seqs
+    );
+    
+    CUDA_CHECK(cudaGetLastError());
+}
+
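One possible guard, not part of this patch: with very large head dimensions the dynamic shared memory request above could exceed the default per-block limit (48 KB on most GPUs), in which case the kernel has to opt in before the launch. A sketch:

```cuda
// Hypothetical guard before the launch above; smem_size as computed there.
if (smem_size > 48 * 1024) {
    CUDA_CHECK(cudaFuncSetAttribute(delta_net_recurrent_f32_kernel,
                                    cudaFuncAttributeMaxDynamicSharedMemorySize,
                                    (int) smem_size));
}
```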
+// Chunked kernel
+__global__ void delta_net_chunked_f32_kernel(
+    const float * __restrict__ q,
+    const float * __restrict__ k,
+    const float * __restrict__ v,
+    const float * __restrict__ g_cumsum,
+    const float * __restrict__ state_in,
+    const float * __restrict__ decay_mask,
+    const float * __restrict__ v_beta,
+    const float * __restrict__ k_beta,
+    const float * __restrict__ attn_in,
+    float * __restrict__ output,
+    float * __restrict__ state_out,
+    float * __restrict__ intermediate_global,  // Global memory for intermediate matrices
+    int S_v, int H_v, int n_tokens, int n_seqs, int chunk_size, int num_chunks) {
+
+    const int head = blockIdx.x;
+    const int seq  = blockIdx.y;
+    const int tid  = threadIdx.x;
+
+    if (head >= H_v || seq >= n_seqs) return;
+    
+    // Calculate offset for this block's intermediate storage
+    const size_t block_idx = (size_t)seq * H_v + head;
+    // Each block needs: 4*chunk_size*S_v (value, k_cumdecay, v_prime, v_new) + chunk_size*chunk_size (attn_new)
+    const size_t per_block_floats = 4 * (size_t)chunk_size * (size_t)S_v + (size_t)chunk_size * (size_t)chunk_size;
+    const size_t intermediate_offset = block_idx * per_block_floats;
+
+    // Offset helpers matching CPU layout
+    auto off_qkv = [&](const float * base, int h, int c, int i, int d) -> size_t {
+        // dims: [S_v, chunk_size, H_v*num_chunks, n_seqs]
+        const int hc = h * num_chunks + c;
+        return (size_t)d + (size_t)i * S_v + (size_t)hc * (size_t)(chunk_size * S_v) + (size_t)seq * (size_t)(chunk_size * S_v * H_v * num_chunks);
+    };
+    auto off_attn = [&](int h, int c, int i, int j) -> size_t {
+        // dims: [chunk_size, chunk_size, H_v*num_chunks, n_seqs]
+        const int hc = h * num_chunks + c;
+        return (size_t)j + (size_t)i * chunk_size + (size_t)hc * (size_t)(chunk_size * chunk_size) + (size_t)seq * (size_t)(chunk_size * chunk_size * H_v * num_chunks);
+    };
+    auto off_g = [&](int h, int c, int t) -> size_t {
+        // dims: [chunk_size, 1, H_v*num_chunks, n_seqs]
+        const int hc = h * num_chunks + c;
+        return (size_t)t + (size_t)hc * (size_t)chunk_size + (size_t)seq * (size_t)(chunk_size * H_v * num_chunks);
+    };
+    auto off_state = [&](int i, int j) -> size_t {
+        // dims: [S_v, S_v, H_v, n_seqs]
+        const size_t state_base = (size_t)head * (size_t)(S_v * S_v) + (size_t)seq * (size_t)(S_v * S_v * H_v);
+        return (size_t)j + (size_t)i * S_v + state_base;
+    };
+    auto off_out = [&](int global_token, int d) -> size_t {
+        // dims: [S_v, n_tokens, H_v, n_seqs]
+        // CPU layout: d + token * S_v + head * (n_tokens * S_v) + seq * (n_tokens * S_v * H_v)
+        return (size_t)d + (size_t)global_token * S_v + (size_t)head * (n_tokens * S_v) + (size_t)seq * (n_tokens * S_v * H_v);
+    };
+
+    // Shared memory: only attn_pre + row_buf (small, fits in shared memory)
+    extern __shared__ float shmem[];
+    float * attn_pre = shmem;
+    float * row_buf  = attn_pre + (size_t)chunk_size * chunk_size;
+    
+    // Global memory pointers for intermediate matrices (avoids shared memory overflow)
+    float * value      = intermediate_global + intermediate_offset;
+    float * k_cumdecay = value + (size_t)chunk_size * S_v;
+    float * v_prime    = k_cumdecay + (size_t)chunk_size * S_v;
+    float * v_new      = v_prime + (size_t)chunk_size * S_v;
+
+    // Initialize state_out from state_in
+    for (int idx = tid; idx < S_v * S_v; idx += blockDim.x) {
+        const int i = idx / S_v;
+        const int j = idx % S_v;
+        state_out[off_state(i, j)] = state_in[off_state(i, j)];
+    }
+    __syncthreads();
+
+    // Process each chunk
+    for (int chunk = 0; chunk < num_chunks; ++chunk) {
+        const int n_tokens_chunk = (chunk == num_chunks - 1 && n_tokens % chunk_size != 0)
+            ? (n_tokens % chunk_size)
+            : chunk_size;
+
+        // Initialize all attn_pre to zero first
+        for (int idx = tid; idx < chunk_size * chunk_size; idx += blockDim.x) {
+            attn_pre[idx] = 0.0f;
+        }
+        __syncthreads();
+        
+        // Copy attn_in tile to attn_pre (only valid n_tokens_chunk rows/cols)
+        for (int idx = tid; idx < n_tokens_chunk * n_tokens_chunk; idx += blockDim.x) {
+            int irow = idx / n_tokens_chunk;
+            int jcol = idx % n_tokens_chunk;
+            attn_pre[irow * chunk_size + jcol] = LDG(&attn_in[off_attn(head, chunk, irow, jcol)]);
+        }
+        __syncthreads();
+
+        // Triangular updates: for i in 1..n_tokens_chunk-1
+        // Python: attn[..., i, :i] = row + (row.unsqueeze(-1) * sub).sum(-2)
+        // where row = attn[..., i, :i] and sub = attn[..., :i, :i]
+        // CPU copies row first, then sub, to avoid reading modified values
+        for (int irow = 1; irow < n_tokens_chunk; ++irow) {
+            // Step 1: Copy row = attn[irow, 0:irow] into row_buf
+            for (int k = tid; k < irow; k += blockDim.x) {
+                row_buf[k] = attn_pre[irow * chunk_size + k];
+            }
+            __syncthreads();
+            
+            // Step 2: Compute new values for attn[irow, 0:irow]
+            // The sub matrix attn[:irow, :irow] is read from the CURRENT attn_pre
+            // (which contains updates from previous irow iterations)
+            for (int j = tid; j < irow; j += blockDim.x) {
+                // Compute sum_k (row[k] * sub[k, j]) where k in [0, irow)
+                // sub[k, j] = attn_pre[k, j] for k < irow, j < irow
+                float sum = 0.0f;
+                for (int k = 0; k < irow; ++k) {
+                    sum += row_buf[k] * attn_pre[k * chunk_size + j];
+                }
+                
+                // Update: attn[irow, j] = row[j] + sum
+                attn_pre[irow * chunk_size + j] = row_buf[j] + sum;
+            }
+            __syncthreads();
+        }
+        // Add identity to diagonal
+        for (int d = tid; d < n_tokens_chunk; d += blockDim.x) {
+            attn_pre[d * chunk_size + d] += 1.0f;
+        }
+        __syncthreads();
+
+        // ========== OPTIMIZATION: Precompute intermediate matrices in global memory ==========
+        // This avoids recomputing the same products for every output element below.
+        // Note: value, k_cumdecay, v_prime, v_new were set up above to point into global memory.
+        
+        // Precompute g_exp for all tokens in this chunk and keep it in shared row_buf
+        float * g_exp_buf = row_buf;
+        for (int t = tid; t < n_tokens_chunk; t += blockDim.x) {
+            g_exp_buf[t] = __expf(g_cumsum[off_g(head, chunk, t)]);
+        }
+        __syncthreads();
+
+        // Compute value = attn_pre @ v_beta [n_tokens_chunk x S_v]
+        for (int idx = tid; idx < n_tokens_chunk * S_v; idx += blockDim.x) {
+            const int row = idx / S_v;
+            const int col = idx % S_v;
+            float sum = 0.0f;
+            const float * __restrict__ pv = &v_beta[off_qkv(v_beta, head, chunk, 0, col)];
+            #pragma unroll 4
+            for (int k = 0; k < n_tokens_chunk; ++k) {
+                sum = FMA(attn_pre[row * chunk_size + k], LDG(pv), sum);
+                pv += (size_t)S_v;
+            }
+            value[row * S_v + col] = sum;
+        }
+        __syncthreads();
+        
+        // Compute k_cumdecay = attn_pre @ (k_beta * exp(g)) [n_tokens_chunk x S_v]
+        for (int idx = tid; idx < n_tokens_chunk * S_v; idx += blockDim.x) {
+            const int row = idx / S_v;
+            const int col = idx % S_v;
+            float sum = 0.0f;
+            const float * __restrict__ pk = &k_beta[off_qkv(k_beta, head, chunk, 0, col)];
+            #pragma unroll 4
+            for (int k = 0; k < n_tokens_chunk; ++k) {
+                sum = FMA(attn_pre[row * chunk_size + k], LDG(pk) * g_exp_buf[k], sum);
+                pk += (size_t)S_v;
+            }
+            k_cumdecay[row * S_v + col] = sum;
+        }
+        __syncthreads();
+        
+        // Compute v_prime = k_cumdecay @ state [n_tokens_chunk x S_v]
+        for (int idx = tid; idx < n_tokens_chunk * S_v; idx += blockDim.x) {
+            const int row = idx / S_v;
+            const int col = idx % S_v;
+            float sum = 0.0f;
+            const float * __restrict__ pstate_col = &state_out[(size_t)col + (size_t)head * (size_t)(S_v * S_v) + (size_t)seq * (size_t)(S_v * S_v * H_v)];
+            #pragma unroll 4
+            for (int k = 0; k < S_v; ++k) {
+                sum = FMA(k_cumdecay[row * S_v + k], pstate_col[(size_t)k * (size_t)S_v], sum);
+            }
+            v_prime[row * S_v + col] = sum;
+        }
+        __syncthreads();
+        
+        // Compute v_new = value - v_prime [n_tokens_chunk x S_v]
+        for (int idx = tid; idx < n_tokens_chunk * S_v; idx += blockDim.x) {
+            v_new[idx] = value[idx] - v_prime[idx];
+        }
+        __syncthreads();
+        
+        // ========== OPTIMIZATION 2: Precompute q@k attention matrix ==========
+        // Allocate space for attn_new in global memory to avoid recomputing q@k
+        float * attn_new = v_new + (size_t)chunk_size * S_v;  // Reuse space after v_new computation
+        
+        // Compute attn_new = (q @ k.T) * decay_mask [n_tokens_chunk x n_tokens_chunk]
+        // Each thread computes multiple elements
+        for (int idx = tid; idx < n_tokens_chunk * n_tokens_chunk; idx += blockDim.x) {
+            const int i = idx / n_tokens_chunk;
+            const int j = idx % n_tokens_chunk;
+            
+            if (j <= i) {  // Only lower triangular (causal mask)
+                float qk_dot = 0.0f;
+                const float * __restrict__ pq = &q[off_qkv(q, head, chunk, i, 0)];
+                const float * __restrict__ pk = &k[off_qkv(k, head, chunk, j, 0)];
+                int d = 0;
+                for (; d + 3 < S_v; d += 4) {
+                    qk_dot = FMA(LDG(pq + d + 0), LDG(pk + d + 0), qk_dot);
+                    qk_dot = FMA(LDG(pq + d + 1), LDG(pk + d + 1), qk_dot);
+                    qk_dot = FMA(LDG(pq + d + 2), LDG(pk + d + 2), qk_dot);
+                    qk_dot = FMA(LDG(pq + d + 3), LDG(pk + d + 3), qk_dot);
+                }
+                for (; d < S_v; ++d) {
+                    qk_dot = FMA(LDG(pq + d), LDG(pk + d), qk_dot);
+                }
+                attn_new[i * chunk_size + j] = qk_dot * LDG(&decay_mask[off_attn(head, chunk, i, j)]);
+            } else {
+                attn_new[i * chunk_size + j] = 0.0f;  // Upper triangular is zero
+            }
+        }
+        __syncthreads();
+        
+        
+        // ========== Now compute output using PRECOMPUTED matrices ==========
+        for (int idx = tid; idx < n_tokens_chunk * S_v; idx += blockDim.x) {
+            const int row = idx / S_v;
+            const int col = idx % S_v;
+            
+            // attn_inter = (q * exp(g)) @ state - use precomputed g_exp
+            float attn_inter = 0.0f;
+            const float g_exp = g_exp_buf[row];
+            const float * __restrict__ pqrow = &q[off_qkv(q, head, chunk, row, 0)];
+            const float * __restrict__ pstate_col = &state_out[(size_t)col + (size_t)head * (size_t)(S_v * S_v) + (size_t)seq * (size_t)(S_v * S_v * H_v)];
+            #pragma unroll 4
+            for (int k_idx = 0; k_idx < S_v; ++k_idx) {
+                attn_inter = FMA(LDG(pqrow + k_idx) * g_exp, pstate_col[(size_t)k_idx * (size_t)S_v], attn_inter);
+            }
+            
+            // core_attn_out = attn_new @ v_new using PRECOMPUTED attn_new and v_new!
+            float core_attn_out = 0.0f;
+            for (int k_idx = 0; k_idx <= row; ++k_idx) {
+                // Use precomputed attn_new - NO q@k computation!
+                core_attn_out += attn_new[row * chunk_size + k_idx] * v_new[k_idx * S_v + col];
+            }
+            
+            const int global_token = chunk * chunk_size + row;
+            if (global_token < n_tokens) {
+                output[off_out(global_token, col)] = attn_inter + core_attn_out;
+            }
+        }
+        __syncthreads();
+
+        // ========== Update state using PRECOMPUTED v_new ==========
+        // Precompute g_diff_exp values (reuse g_exp_buf). g_last is recomputed from
+        // global memory rather than read from the shared buffer, so overwriting the
+        // buffer below cannot race with that read.
+        const float g_last_log = g_cumsum[off_g(head, chunk, n_tokens_chunk - 1)];
+        const float g_last = __expf(g_last_log);
+        float * g_diff_buf = g_exp_buf;  // Reuse buffer
+        // Use exp of the difference to avoid divide-by-zero/underflow issues
+        for (int t = tid; t < n_tokens_chunk; t += blockDim.x) {
+            g_diff_buf[t] = __expf(g_last_log - g_cumsum[off_g(head, chunk, t)]);
+        }
+        __syncthreads();
+        
+        for (int idx = tid; idx < S_v * S_v; idx += blockDim.x) {
+            const int i = idx / S_v;
+            const int j = idx % S_v;
+            
+            float new_state_val = state_out[off_state(i, j)] * g_last;
+            
+            // Use precomputed v_new and g_diff - NO exp() calls in loop!
+            const float * __restrict__ pk_tok = &k[off_qkv(k, head, chunk, 0, i)];
+            #pragma unroll 4
+            for (int t = 0; t < n_tokens_chunk; ++t) {
+                new_state_val = FMA(LDG(pk_tok), g_diff_buf[t] * v_new[t * S_v + j], new_state_val);
+                pk_tok += (size_t)S_v;
+            }
+            
+            state_out[off_state(i, j)] = new_state_val;
+        }
+        __syncthreads();
+    }
+}
+
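In matrix form, for a chunk of $C$ tokens with rows $q_t, k_t$, cumulative gates $g_t$ (`g_cumsum`), decay mask $D$, the $\beta$-scaled inputs $K_\beta, V_\beta$ (where $K_\beta \odot e^{g}$ denotes row $t$ of $K_\beta$ scaled by $e^{g_t}$), and the matrix $M$ obtained from `attn_in` by the triangular recursion plus identity above, the kernel computes:

$$
\tilde V = M V_\beta - \big(M\,(K_\beta \odot e^{g})\big)\,S, \qquad
A = \mathrm{tril}\!\big(Q K^{\top}\big) \odot D,
$$
$$
o_t = e^{g_t}\, q_t^{\top} S + \sum_{s \le t} A_{t s}\, \tilde v_s, \qquad
S \leftarrow e^{g_C} S + \sum_{t=1}^{C} e^{\,g_C - g_t}\, k_t\, \tilde v_t^{\top}.
$$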
+void ggml_cuda_op_delta_net(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+    // inputs
+    const ggml_tensor * src0 = dst->src[0];  // q
+    const ggml_tensor * src1 = dst->src[1];  // k
+    const ggml_tensor * src2 = dst->src[2];  // v
+    const ggml_tensor * src3 = dst->src[3];  // g (cumsum)
+    const ggml_tensor * src4 = dst->src[4];  // state
+    const ggml_tensor * src5 = dst->src[5];  // decay_mask
+    const ggml_tensor * src6 = dst->src[6];  // v_beta
+    const ggml_tensor * src7 = dst->src[7];  // k_beta
+    const ggml_tensor * src8 = dst->src[8];  // attn (pre)
+
+    const int H_v       = (int) dst->op_params[0];
+    const int S_v       = (int) dst->op_params[2];
+    const int n_tokens  = (int) dst->op_params[3];
+    const int n_seqs    = (int) src0->ne[3];
+    const int chunk_size = (int) GGML_DELTA_NET_CHUNK;
+    const int pad_size   = (chunk_size - n_tokens % chunk_size) % chunk_size;
+    const int num_chunks = (n_tokens + pad_size) / chunk_size;
+
+    const float * q_d   = (const float *) src0->data;
+    const float * k_d   = (const float *) src1->data;
+    const float * v_d   = (const float *) src2->data;
+    const float * g_d   = (const float *) src3->data;
+    const float * state_in_d = (const float *) src4->data;
+    const float * decay_d = (const float *) src5->data;
+    const float * vbeta_d = (const float *) src6->data;
+    const float * kbeta_d = (const float *) src7->data;
+    const float * attn_in_d = (const float *) src8->data;
+
+    float * dst_d = (float *) dst->data;
+    float * output_d   = dst_d;
+    float * state_out_d = dst_d + (size_t) S_v * H_v * n_tokens * n_seqs;
+
+    dim3 grid(H_v, n_seqs);
+    int block_x2 = 256;
+    if (S_v < 256) block_x2 = (S_v >= 128 ? 128 : (S_v >= 64 ? 64 : (S_v >= 32 ? 32 : 16)));
+    dim3 block(block_x2, 1, 1);
+
+    cudaStream_t stream = ctx.stream();
+    
+    // Allocate global memory for the per-block intermediate matrices:
+    // - value, k_cumdecay, v_prime, v_new: 4 * chunk_size * S_v floats
+    // - attn_new: chunk_size * chunk_size floats, stored right after v_new
+    // Note: ggml_cuda_pool_alloc<float> takes a count of elements, not bytes.
+    const size_t intermediate_floats = 4 * (size_t)chunk_size * S_v + (size_t)chunk_size * chunk_size;
+    ggml_cuda_pool_alloc<float> intermediate_alloc(ctx.pool(), intermediate_floats * H_v * n_seqs);
+    float * intermediate_d = intermediate_alloc.get();
+    
+    // Shared memory per block: only attn_pre + row_buf (much smaller!)
+    size_t smem = ((size_t)chunk_size * chunk_size + chunk_size) * sizeof(float);
+    
+    delta_net_chunked_f32_kernel<<<grid, block, smem, stream>>>(
+        q_d, k_d, v_d, g_d, state_in_d, decay_d, vbeta_d, kbeta_d, attn_in_d,
+        output_d, state_out_d, intermediate_d,
+        S_v, H_v, n_tokens, n_seqs, chunk_size, num_chunks
+    );
+    
+    CUDA_CHECK(cudaGetLastError());
+}
+

+ 5 - 0
ggml/src/ggml-cuda/delta-net.cuh

@@ -0,0 +1,5 @@
+#include "common.cuh"
+
+void ggml_cuda_op_delta_net(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
+void ggml_cuda_op_delta_net_recurrent(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
+

+ 20 - 0
ggml/src/ggml-cuda/ggml-cuda.cu

@@ -43,6 +43,9 @@
 #include "ggml-cuda/ssm-scan.cuh"
 #include "ggml-cuda/sum.cuh"
 #include "ggml-cuda/sumrows.cuh"
+#include "ggml-cuda/cumsum.cuh"
+#include "ggml-cuda/tri.cuh"
+#include "ggml-cuda/delta-net.cuh"
 #include "ggml-cuda/mean.cuh"
 #include "ggml-cuda/tsembd.cuh"
 #include "ggml-cuda/unary.cuh"
@@ -2445,6 +2448,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
         case GGML_OP_DIAG_MASK_INF:
             ggml_cuda_op_diag_mask_inf(ctx, dst);
             break;
+        case GGML_OP_TRI:
+            ggml_cuda_op_tri(ctx, dst);
+            break;
         case GGML_OP_SOFT_MAX:
             ggml_cuda_op_soft_max(ctx, dst);
             break;
@@ -2487,6 +2493,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
         case GGML_OP_SUM_ROWS:
             ggml_cuda_op_sum_rows(ctx, dst);
             break;
+        case GGML_OP_CUMSUM:
+            ggml_cuda_op_cumsum(ctx, dst);
+            break;
         case GGML_OP_MEAN:
             ggml_cuda_op_mean(ctx, dst);
             break;
@@ -2514,6 +2523,12 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
         case GGML_OP_RWKV_WKV7:
             ggml_cuda_op_rwkv_wkv7(ctx, dst);
             break;
+        case GGML_OP_DELTA_NET:
+            ggml_cuda_op_delta_net(ctx, dst);
+            break;
+        case GGML_OP_DELTA_NET_RECURRENT:
+            ggml_cuda_op_delta_net_recurrent(ctx, dst);
+            break;
         case GGML_OP_CROSS_ENTROPY_LOSS_BACK:
             ggml_cuda_cross_entropy_loss_back(ctx, dst);
             break;
@@ -3569,6 +3584,8 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
             return true;
         case GGML_OP_DIAG_MASK_INF:
             return true;
+        case GGML_OP_TRI:
+            return true;
         case GGML_OP_SOFT_MAX:
             return true;
         case GGML_OP_SOFT_MAX_BACK: {
@@ -3608,7 +3625,10 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
         case GGML_OP_RWKV_WKV6:
         case GGML_OP_GATED_LINEAR_ATTN:
         case GGML_OP_RWKV_WKV7:
+        case GGML_OP_DELTA_NET_RECURRENT:
             return true;
+        case GGML_OP_DELTA_NET:
+            return true;
         case GGML_OP_FLASH_ATTN_EXT:
             return ggml_cuda_flash_attn_ext_supported(dev_ctx->device, op);
         case GGML_OP_CROSS_ENTROPY_LOSS:

+ 162 - 0
ggml/src/ggml-cuda/tri.cu

@@ -0,0 +1,162 @@
+#include "tri.cuh"
+
+// Optimized: process 4 elements per thread with float4
+template<ggml_tri_type type>
+__global__ void tri_f32_kernel(const float * __restrict__ x, float * __restrict__ dst,
+                                int64_t ne0, int64_t ne1, int64_t ne2, int64_t ne3,
+                                int64_t nb0, int64_t nb1, int64_t nb2, int64_t nb3,
+                                int64_t dst_nb0, int64_t dst_nb1, int64_t dst_nb2, int64_t dst_nb3,
+                                float constant, bool keep_org_val) {
+    
+    const int64_t i03 = blockIdx.z;
+    const int64_t i02 = blockIdx.y;
+    const int64_t i01 = blockIdx.x;
+    
+    if (i03 >= ne3 || i02 >= ne2 || i01 >= ne1) return;
+    
+    const float * src_row = (const float *)((const char *)x + i01*nb1 + i02*nb2 + i03*nb3);
+    float * dst_row = (float *)((char *)dst + i01*dst_nb1 + i02*dst_nb2 + i03*dst_nb3);
+    
+    const int row = i01;
+    const int tid = threadIdx.x;
+    
+    // Vectorized path: process 4 elements at a time with float4, but only when both
+    // rows are 16-byte aligned; otherwise the scalar tail loop below handles the whole row.
+    const bool aligned = ((((size_t) src_row) | ((size_t) dst_row)) % 16) == 0;
+    const int64_t vec_count = aligned ? ne0 / 4 : 0;
+    
+    // Process 4 elements at a time
+    for (int64_t vec_idx = tid; vec_idx < vec_count; vec_idx += blockDim.x) {
+        const int64_t col_base = vec_idx * 4;
+        
+        // Load 4 values
+        float4 src_val = *reinterpret_cast<const float4*>(&src_row[col_base]);
+        float4 dst_val;
+        
+        // Process each element
+        #pragma unroll
+        for (int i = 0; i < 4; ++i) {
+            const int col = col_base + i;
+            bool cmp;
+            switch (type) {
+                case GGML_TRI_TYPE_LOWER:      cmp = col < row;  break;
+                case GGML_TRI_TYPE_LOWER_DIAG: cmp = col <= row; break;
+                case GGML_TRI_TYPE_UPPER:      cmp = col > row;  break;
+                case GGML_TRI_TYPE_UPPER_DIAG: cmp = col >= row; break;
+                default: cmp = false; break;
+            }
+            (&dst_val.x)[i] = cmp ? (keep_org_val ? (&src_val.x)[i] : constant) : 0.0f;
+        }
+        
+        // Store 4 values
+        *reinterpret_cast<float4*>(&dst_row[col_base]) = dst_val;
+    }
+    
+    // Handle remainder elements
+    for (int64_t i = tid + vec_count * 4; i < ne0; i += blockDim.x) {
+        const int col = i;
+        bool cmp;
+        switch (type) {
+            case GGML_TRI_TYPE_LOWER:      cmp = col < row;  break;
+            case GGML_TRI_TYPE_LOWER_DIAG: cmp = col <= row; break;
+            case GGML_TRI_TYPE_UPPER:      cmp = col > row;  break;
+            case GGML_TRI_TYPE_UPPER_DIAG: cmp = col >= row; break;
+            default: cmp = false; break;
+        }
+        dst_row[col] = cmp ? (keep_org_val ? src_row[col] : constant) : 0.0f;
+    }
+}
+
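As a worked example of the masking, for a 3x3 row block with `keep_org_val == false` and `constant == 1` (recall from the host code below that passing NaN as the constant keeps the original values instead), the four variants produce:

$$
\mathrm{LOWER}: \begin{pmatrix} 0&0&0\\ 1&0&0\\ 1&1&0 \end{pmatrix} \quad
\mathrm{LOWER\_DIAG}: \begin{pmatrix} 1&0&0\\ 1&1&0\\ 1&1&1 \end{pmatrix} \quad
\mathrm{UPPER}: \begin{pmatrix} 0&1&1\\ 0&0&1\\ 0&0&0 \end{pmatrix} \quad
\mathrm{UPPER\_DIAG}: \begin{pmatrix} 1&1&1\\ 0&1&1\\ 0&0&1 \end{pmatrix}
$$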
+void ggml_cuda_op_tri(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+    const ggml_tensor * src0 = dst->src[0];
+    
+    GGML_ASSERT(src0->type == GGML_TYPE_F32);
+    GGML_ASSERT(dst->type == GGML_TYPE_F32);
+    GGML_ASSERT(ggml_is_contiguous(src0));
+    GGML_ASSERT(src0->ne[0] == src0->ne[1]); // Square matrices
+    
+    const float * src0_d = (const float *) src0->data;
+    float * dst_d = (float *) dst->data;
+    
+    cudaStream_t stream = ctx.stream();
+    
+    const int64_t ne0 = src0->ne[0];
+    const int64_t ne1 = src0->ne[1];
+    const int64_t ne2 = src0->ne[2];
+    const int64_t ne3 = src0->ne[3];
+    
+    const int64_t nb0 = src0->nb[0];
+    const int64_t nb1 = src0->nb[1];
+    const int64_t nb2 = src0->nb[2];
+    const int64_t nb3 = src0->nb[3];
+    
+    const int64_t dst_nb0 = dst->nb[0];
+    const int64_t dst_nb1 = dst->nb[1];
+    const int64_t dst_nb2 = dst->nb[2];
+    const int64_t dst_nb3 = dst->nb[3];
+    
+    ggml_tri_type ttype = (ggml_tri_type) dst->op_params[0];
+    float constant = ggml_get_op_params_f32(dst, 1);
+    bool keep_org_val = isnan(constant);
+    
+    // Launch kernel
+    // Grid: one block per row (ne1 rows, ne2 x ne3 batches).
+    // Each block's threads stride across the ne0 columns of its row, so a single
+    // block per row also covers ne0 > 1024.
+    dim3 grid(ne1, ne2, ne3);
+    const int block_size = min((int)ne0, 1024);
+    
+    switch (ttype) {
+        case GGML_TRI_TYPE_LOWER:
+            tri_f32_kernel<GGML_TRI_TYPE_LOWER><<<grid, block_size, 0, stream>>>(
+                src0_d, dst_d, ne0, ne1, ne2, ne3, nb0, nb1, nb2, nb3,
+                dst_nb0, dst_nb1, dst_nb2, dst_nb3, constant, keep_org_val);
+            break;
+        case GGML_TRI_TYPE_LOWER_DIAG:
+            tri_f32_kernel<GGML_TRI_TYPE_LOWER_DIAG><<<grid, block_size, 0, stream>>>(
+                src0_d, dst_d, ne0, ne1, ne2, ne3, nb0, nb1, nb2, nb3,
+                dst_nb0, dst_nb1, dst_nb2, dst_nb3, constant, keep_org_val);
+            break;
+        case GGML_TRI_TYPE_UPPER:
+            tri_f32_kernel<GGML_TRI_TYPE_UPPER><<<grid, block_size, 0, stream>>>(
+                src0_d, dst_d, ne0, ne1, ne2, ne3, nb0, nb1, nb2, nb3,
+                dst_nb0, dst_nb1, dst_nb2, dst_nb3, constant, keep_org_val);
+            break;
+        case GGML_TRI_TYPE_UPPER_DIAG:
+            tri_f32_kernel<GGML_TRI_TYPE_UPPER_DIAG><<<grid, block_size, 0, stream>>>(
+                src0_d, dst_d, ne0, ne1, ne2, ne3, nb0, nb1, nb2, nb3,
+                dst_nb0, dst_nb1, dst_nb2, dst_nb3, constant, keep_org_val);
+            break;
+    }
+}
+

+ 4 - 0
ggml/src/ggml-cuda/tri.cuh

@@ -0,0 +1,4 @@
+#include "common.cuh"
+
+void ggml_cuda_op_tri(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
+