Explorar el Código

metal : add cumsum (#17305)

Georgi Gerganov hace 2 meses
padre
commit
1a139644a8

+ 38 - 0
ggml/src/ggml-metal/ggml-metal-device.cpp

@@ -318,6 +318,44 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_sum_rows(ggml_metal_librar
     return res;
 }
 
+ggml_metal_pipeline_t ggml_metal_library_get_pipeline_cumsum_blk(ggml_metal_library_t lib, const ggml_tensor * op) {
+    GGML_ASSERT(op->op == GGML_OP_CUMSUM);
+
+    char base[256];
+    char name[256];
+
+    snprintf(base, 256, "kernel_cumsum_blk_%s", ggml_type_name(op->src[0]->type));
+    snprintf(name, 256, "%s", base);
+
+    ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name);
+    if (res) {
+        return res;
+    }
+
+    res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
+
+    return res;
+}
+
+ggml_metal_pipeline_t ggml_metal_library_get_pipeline_cumsum_add(ggml_metal_library_t lib, const ggml_tensor * op) {
+    GGML_ASSERT(op->op == GGML_OP_CUMSUM);
+
+    char base[256];
+    char name[256];
+
+    snprintf(base, 256, "kernel_cumsum_add_%s", ggml_type_name(op->src[0]->type));
+    snprintf(name, 256, "%s", base);
+
+    ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name);
+    if (res) {
+        return res;
+    }
+
+    res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
+
+    return res;
+}
+
 ggml_metal_pipeline_t ggml_metal_library_get_pipeline_soft_max(ggml_metal_library_t lib, const ggml_tensor * op) {
     GGML_ASSERT(!op->src[1] || op->src[1]->type == GGML_TYPE_F16 || op->src[1]->type == GGML_TYPE_F32);
 

+ 2 - 0
ggml/src/ggml-metal/ggml-metal-device.h

@@ -113,6 +113,8 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_unary             (ggml_me
 ggml_metal_pipeline_t ggml_metal_library_get_pipeline_glu               (ggml_metal_library_t lib, const struct ggml_tensor * op);
 ggml_metal_pipeline_t ggml_metal_library_get_pipeline_sum               (ggml_metal_library_t lib, const struct ggml_tensor * op);
 ggml_metal_pipeline_t ggml_metal_library_get_pipeline_sum_rows          (ggml_metal_library_t lib, const struct ggml_tensor * op);
+ggml_metal_pipeline_t ggml_metal_library_get_pipeline_cumsum_blk        (ggml_metal_library_t lib, const struct ggml_tensor * op);
+ggml_metal_pipeline_t ggml_metal_library_get_pipeline_cumsum_add        (ggml_metal_library_t lib, const struct ggml_tensor * op);
 ggml_metal_pipeline_t ggml_metal_library_get_pipeline_soft_max          (ggml_metal_library_t lib, const struct ggml_tensor * op);
 ggml_metal_pipeline_t ggml_metal_library_get_pipeline_ssm_conv          (ggml_metal_library_t lib, const struct ggml_tensor * op);
 ggml_metal_pipeline_t ggml_metal_library_get_pipeline_ssm_scan          (ggml_metal_library_t lib, const struct ggml_tensor * op);

+ 1 - 0
ggml/src/ggml-metal/ggml-metal-device.m

@@ -870,6 +870,7 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te
         case GGML_OP_SUM:
             return has_simdgroup_reduction && ggml_is_contiguous(op->src[0]);
         case GGML_OP_SUM_ROWS:
+        case GGML_OP_CUMSUM:
         case GGML_OP_MEAN:
         case GGML_OP_SOFT_MAX:
         case GGML_OP_GROUP_NORM:

+ 39 - 0
ggml/src/ggml-metal/ggml-metal-impl.h

@@ -612,6 +612,45 @@ typedef struct {
     uint64_t nb3;
 } ggml_metal_kargs_sum_rows;
 
+typedef struct {
+    int64_t  ne00;
+    int64_t  ne01;
+    int64_t  ne02;
+    int64_t  ne03;
+    uint64_t nb00;
+    uint64_t nb01;
+    uint64_t nb02;
+    uint64_t nb03;
+    int64_t  net0;
+    int64_t  net1;
+    int64_t  net2;
+    int64_t  net3;
+    uint64_t nbt0;
+    uint64_t nbt1;
+    uint64_t nbt2;
+    uint64_t nbt3;
+    bool     outb;
+} ggml_metal_kargs_cumsum_blk;
+
+typedef struct {
+    int64_t  ne00;
+    int64_t  ne01;
+    int64_t  ne02;
+    int64_t  ne03;
+    uint64_t nb00;
+    uint64_t nb01;
+    uint64_t nb02;
+    uint64_t nb03;
+    int64_t  net0;
+    int64_t  net1;
+    int64_t  net2;
+    int64_t  net3;
+    uint64_t nbt0;
+    uint64_t nbt1;
+    uint64_t nbt2;
+    uint64_t nbt3;
+} ggml_metal_kargs_cumsum_add;
+
 typedef struct {
     int32_t  ne00;
     int32_t  ne01;

+ 184 - 63
ggml/src/ggml-metal/ggml-metal-ops.cpp

@@ -311,6 +311,10 @@ static int ggml_metal_op_encode_impl(ggml_metal_op_t ctx, int idx) {
             {
                 n_fuse = ggml_metal_op_sum_rows(ctx, idx);
             } break;
+        case GGML_OP_CUMSUM:
+            {
+                n_fuse = ggml_metal_op_cumsum(ctx, idx);
+            } break;
         case GGML_OP_SOFT_MAX:
             {
                 n_fuse = ggml_metal_op_soft_max(ctx, idx);
@@ -539,7 +543,7 @@ int ggml_metal_op_repeat(ggml_metal_op_t ctx, int idx) {
     GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
     GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
     GGML_TENSOR_LOCALS( int32_t, ne,  op,         ne);
-    GGML_TENSOR_LOCALS(uint32_t, nb,  op,         nb);
+    GGML_TENSOR_LOCALS(uint64_t, nb,  op,         nb);
 
     ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_repeat(lib, op->type);
 
@@ -585,7 +589,7 @@ int ggml_metal_op_acc(ggml_metal_op_t ctx, int idx) {
     GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne);
     GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb);
     GGML_TENSOR_LOCALS( int32_t, ne,  op,         ne);
-    GGML_TENSOR_LOCALS(uint32_t, nb,  op,         nb);
+    GGML_TENSOR_LOCALS(uint64_t, nb,  op,         nb);
 
     GGML_ASSERT(op->src[0]->type == GGML_TYPE_F32);
     GGML_ASSERT(op->src[1]->type == GGML_TYPE_F32);
@@ -694,7 +698,7 @@ int ggml_metal_op_scale(ggml_metal_op_t ctx, int idx) {
     GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
     GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
     GGML_TENSOR_LOCALS( int32_t, ne,  op,         ne);
-    GGML_TENSOR_LOCALS(uint32_t, nb,  op,         nb);
+    GGML_TENSOR_LOCALS(uint64_t, nb,  op,         nb);
 
     float scale;
     float bias;
@@ -733,7 +737,7 @@ int ggml_metal_op_clamp(ggml_metal_op_t ctx, int idx) {
     GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
     GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
     GGML_TENSOR_LOCALS( int32_t, ne,  op,         ne);
-    GGML_TENSOR_LOCALS(uint32_t, nb,  op,         nb);
+    GGML_TENSOR_LOCALS(uint64_t, nb,  op,         nb);
 
     float min;
     float max;
@@ -772,7 +776,7 @@ int ggml_metal_op_unary(ggml_metal_op_t ctx, int idx) {
     GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
     GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
     GGML_TENSOR_LOCALS( int32_t, ne,  op,         ne);
-    GGML_TENSOR_LOCALS(uint32_t, nb,  op,         nb);
+    GGML_TENSOR_LOCALS(uint64_t, nb,  op,         nb);
 
     int64_t n = ggml_nelements(op);
 
@@ -802,7 +806,7 @@ int ggml_metal_op_glu(ggml_metal_op_t ctx, int idx) {
     GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne);
     GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb);
     GGML_TENSOR_LOCALS( int32_t, ne,  op,         ne);
-    GGML_TENSOR_LOCALS(uint32_t, nb,  op,         nb);
+    GGML_TENSOR_LOCALS(uint64_t, nb,  op,         nb);
 
     if (op->src[1]) {
         GGML_ASSERT(ggml_are_same_shape(op->src[0], op->src[1]));
@@ -834,18 +838,6 @@ int ggml_metal_op_glu(ggml_metal_op_t ctx, int idx) {
 
     const int32_t nth = std::min(ggml_metal_pipeline_max_theads_per_threadgroup(pipeline), ne00/2);
 
-    //[encoder setComputePipelineState:pipeline];
-    //[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
-    //if (src1) {
-    //    [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
-    //} else {
-    //    [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
-    //}
-    //[encoder setBuffer:id_dst  offset:offs_dst  atIndex:2];
-    //[encoder setBytes:&args length:sizeof(args) atIndex:3];
-
-    //[encoder dispatchThreadgroups:MTLSizeMake(nrows, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
-
     ggml_metal_encoder_set_pipeline(enc, pipeline);
     ggml_metal_encoder_set_bytes   (enc, &args, sizeof(args), 0);
     ggml_metal_encoder_set_buffer  (enc, ggml_metal_get_buffer_id(op->src[0]), 1);
@@ -907,7 +899,7 @@ int ggml_metal_op_sum_rows(ggml_metal_op_t ctx, int idx) {
     GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
     GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
     GGML_TENSOR_LOCALS( int32_t, ne,  op,         ne);
-    GGML_TENSOR_LOCALS(uint32_t, nb,  op,         nb);
+    GGML_TENSOR_LOCALS(uint64_t, nb,  op,         nb);
 
     ggml_metal_kargs_sum_rows args = {
         /*.ne00 =*/ ne00,
@@ -941,14 +933,6 @@ int ggml_metal_op_sum_rows(ggml_metal_op_t ctx, int idx) {
 
     const size_t smem = ggml_metal_pipeline_get_smem(pipeline);
 
-    //[encoder setComputePipelineState:pipeline];
-    //[encoder setBytes:&args length:sizeof(args) atIndex:0];
-    //[encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
-    //[encoder setBuffer:id_dst  offset:offs_dst  atIndex:2];
-    //[encoder setThreadgroupMemoryLength:32*sizeof(float) atIndex:0];
-
-    //[encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
-
     ggml_metal_encoder_set_pipeline(enc, pipeline);
     ggml_metal_encoder_set_bytes   (enc, &args, sizeof(args), 0);
     ggml_metal_encoder_set_buffer  (enc, ggml_metal_get_buffer_id(op->src[0]), 1);
@@ -961,6 +945,149 @@ int ggml_metal_op_sum_rows(ggml_metal_op_t ctx, int idx) {
     return 1;
 }
 
+int ggml_metal_op_cumsum(ggml_metal_op_t ctx, int idx) {
+    ggml_tensor * op = ctx->node(idx);
+
+    ggml_metal_library_t lib = ctx->lib;
+    ggml_metal_encoder_t enc = ctx->enc;
+
+    GGML_ASSERT(ggml_is_contiguous_rows(op->src[0]));
+
+    GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
+    GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
+    GGML_TENSOR_LOCALS( int32_t, ne,  op,         ne);
+    GGML_TENSOR_LOCALS(uint64_t, nb,  op,         nb);
+
+    ggml_metal_pipeline_t pipeline_blk = ggml_metal_library_get_pipeline_cumsum_blk(lib, op);
+
+    int nth = 1;
+    while (nth < ne00 && 2*nth <= ggml_metal_pipeline_max_theads_per_threadgroup(pipeline_blk)) {
+        nth *= 2;
+    }
+
+    GGML_ASSERT(ne00 <= nth*nth);
+
+    const int64_t net0 = (ne00 + nth - 1) / nth;
+    const int64_t net1 = ne01;
+    const int64_t net2 = ne02;
+    const int64_t net3 = ne03;
+
+    const uint64_t nbt0 = sizeof(float);
+    const uint64_t nbt1 = net0*nbt0;
+    const uint64_t nbt2 = net1*nbt1;
+    const uint64_t nbt3 = net2*nbt2;
+
+    const size_t smem = GGML_PAD(32*sizeof(float), 16);
+
+    ggml_metal_buffer_id bid_src0 = ggml_metal_get_buffer_id(op->src[0]);
+    ggml_metal_buffer_id bid_dst  = ggml_metal_get_buffer_id(op);
+
+    ggml_metal_buffer_id bid_tmp = bid_dst;
+    bid_tmp.offs += ggml_nbytes(op);
+
+    {
+        ggml_metal_kargs_cumsum_blk args = {
+            /*.ne00 =*/ ne00,
+            /*.ne01 =*/ ne01,
+            /*.ne02 =*/ ne02,
+            /*.ne03 =*/ ne03,
+            /*.nb00 =*/ nb00,
+            /*.nb01 =*/ nb01,
+            /*.nb02 =*/ nb02,
+            /*.nb03 =*/ nb03,
+            /*.net0 =*/ net0,
+            /*.net1 =*/ net1,
+            /*.net2 =*/ net2,
+            /*.net3 =*/ net3,
+            /*.nbt0 =*/ nbt0,
+            /*.nbt1 =*/ nbt1,
+            /*.nbt2 =*/ nbt2,
+            /*.nbt3 =*/ nbt3,
+            /*.outb =*/ ne00 > nth,
+        };
+
+        ggml_metal_encoder_set_pipeline(enc, pipeline_blk);
+        ggml_metal_encoder_set_bytes   (enc, &args, sizeof(args), 0);
+        ggml_metal_encoder_set_buffer  (enc, bid_src0, 1);
+        ggml_metal_encoder_set_buffer  (enc, bid_tmp,  2);
+        ggml_metal_encoder_set_buffer  (enc, bid_dst,  3);
+
+        ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0);
+
+        ggml_metal_encoder_dispatch_threadgroups(enc, net0*ne01, ne02, ne03, nth, 1, 1);
+    }
+
+    if (ne00 > nth) {
+        ggml_metal_op_concurrency_reset(ctx);
+
+        {
+            ggml_metal_kargs_cumsum_blk args = {
+                /*.ne00 =*/ net0,
+                /*.ne01 =*/ net1,
+                /*.ne02 =*/ net2,
+                /*.ne03 =*/ net3,
+                /*.nb00 =*/ nbt0,
+                /*.nb01 =*/ nbt1,
+                /*.nb02 =*/ nbt2,
+                /*.nb03 =*/ nbt3,
+                /*.net0 =*/ net0,
+                /*.net1 =*/ net1,
+                /*.net2 =*/ net2,
+                /*.net3 =*/ net3,
+                /*.nbt0 =*/ nbt0,
+                /*.nbt1 =*/ nbt1,
+                /*.nbt2 =*/ nbt2,
+                /*.nbt3 =*/ nbt3,
+                /*.outb =*/ false,
+            };
+
+            ggml_metal_encoder_set_pipeline(enc, pipeline_blk);
+            ggml_metal_encoder_set_bytes   (enc, &args, sizeof(args), 0);
+            ggml_metal_encoder_set_buffer  (enc, bid_tmp, 1);
+            ggml_metal_encoder_set_buffer  (enc, bid_tmp, 2);
+            ggml_metal_encoder_set_buffer  (enc, bid_tmp, 3);
+
+            ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0);
+
+            ggml_metal_encoder_dispatch_threadgroups(enc, net1, net2, net3, nth, 1, 1);
+        }
+
+        ggml_metal_op_concurrency_reset(ctx);
+
+        {
+            ggml_metal_pipeline_t pipeline_add = ggml_metal_library_get_pipeline_cumsum_add(lib, op);
+
+            ggml_metal_kargs_cumsum_add args = {
+                /*.ne00 =*/ ne00,
+                /*.ne01 =*/ ne01,
+                /*.ne02 =*/ ne02,
+                /*.ne03 =*/ ne03,
+                /*.nb00 =*/ nb00,
+                /*.nb01 =*/ nb01,
+                /*.nb02 =*/ nb02,
+                /*.nb03 =*/ nb03,
+                /*.net0 =*/ net0,
+                /*.net1 =*/ net1,
+                /*.net2 =*/ net2,
+                /*.net3 =*/ net3,
+                /*.nbt0 =*/ nbt0,
+                /*.nbt1 =*/ nbt1,
+                /*.nbt2 =*/ nbt2,
+                /*.nbt3 =*/ nbt3,
+            };
+
+            ggml_metal_encoder_set_pipeline(enc, pipeline_add);
+            ggml_metal_encoder_set_bytes   (enc, &args, sizeof(args), 0);
+            ggml_metal_encoder_set_buffer  (enc, bid_tmp, 1);
+            ggml_metal_encoder_set_buffer  (enc, bid_dst, 2);
+
+            ggml_metal_encoder_dispatch_threadgroups(enc, net0*ne01, ne02, ne03, nth, 1, 1);
+        }
+    }
+
+    return 1;
+}
+
 int ggml_metal_op_get_rows(ggml_metal_op_t ctx, int idx) {
     ggml_tensor * op = ctx->node(idx);
 
@@ -972,7 +1099,7 @@ int ggml_metal_op_get_rows(ggml_metal_op_t ctx, int idx) {
     GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne);
     GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb);
     GGML_TENSOR_LOCALS( int32_t, ne,  op,         ne);
-    GGML_TENSOR_LOCALS(uint32_t, nb,  op,         nb);
+    GGML_TENSOR_LOCALS(uint64_t, nb,  op,         nb);
 
     ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_get_rows(lib, op->src[0]->type);
 
@@ -1017,7 +1144,7 @@ int ggml_metal_op_set_rows(ggml_metal_op_t ctx, int idx) {
     GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne);
     GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb);
     GGML_TENSOR_LOCALS( int32_t, ne,  op,         ne);
-    GGML_TENSOR_LOCALS(uint32_t, nb,  op,         nb);
+    GGML_TENSOR_LOCALS(uint64_t, nb,  op,         nb);
 
     ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_set_rows(lib, op->src[1]->type, op->type);
 
@@ -1081,7 +1208,7 @@ int ggml_metal_op_soft_max(ggml_metal_op_t ctx, int idx) {
     GGML_TENSOR_LOCALS( int32_t, ne2, op->src[2], ne);
     GGML_TENSOR_LOCALS(uint64_t, nb2, op->src[2], nb);
     GGML_TENSOR_LOCALS( int32_t, ne,  op,         ne);
-    GGML_TENSOR_LOCALS(uint32_t, nb,  op,         nb);
+    GGML_TENSOR_LOCALS(uint64_t, nb,  op,         nb);
 
     float scale;
     float max_bias;
@@ -1169,7 +1296,7 @@ int ggml_metal_op_ssm_conv(ggml_metal_op_t ctx, int idx) {
     GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne);
     GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb);
     GGML_TENSOR_LOCALS( int32_t, ne,  op,         ne);
-    GGML_TENSOR_LOCALS(uint32_t, nb,  op,         nb);
+    GGML_TENSOR_LOCALS(uint64_t, nb,  op,         nb);
 
     ggml_metal_kargs_ssm_conv args = {
         /*.ne00 =*/ ne00,
@@ -1224,7 +1351,7 @@ int ggml_metal_op_ssm_scan(ggml_metal_op_t ctx, int idx) {
     GGML_TENSOR_LOCALS( int32_t, ne6, op->src[6], ne);
     GGML_TENSOR_LOCALS(uint64_t, nb6, op->src[6], nb);
     GGML_TENSOR_LOCALS( int32_t, ne,  op,         ne);
-    GGML_TENSOR_LOCALS(uint32_t, nb,  op,         nb);
+    GGML_TENSOR_LOCALS(uint64_t, nb,  op,         nb);
 
     const ggml_tensor * src3 = op->src[3];
     const ggml_tensor * src4 = op->src[4];
@@ -1310,7 +1437,7 @@ int ggml_metal_op_rwkv(ggml_metal_op_t ctx, int idx) {
     GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
     GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
     GGML_TENSOR_LOCALS( int32_t, ne,  op,         ne);
-    GGML_TENSOR_LOCALS(uint32_t, nb,  op,         nb);
+    GGML_TENSOR_LOCALS(uint64_t, nb,  op,         nb);
 
     const int64_t B = op->op == GGML_OP_RWKV_WKV6 ? op->src[5]->ne[1] : op->src[6]->ne[1];
     const int64_t T = op->src[0]->ne[2];
@@ -1351,7 +1478,7 @@ int ggml_metal_op_cpy(ggml_metal_op_t ctx, int idx) {
     GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
     GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
     GGML_TENSOR_LOCALS( int32_t, ne,  op,         ne);
-    GGML_TENSOR_LOCALS(uint32_t, nb,  op,         nb);
+    GGML_TENSOR_LOCALS(uint64_t, nb,  op,         nb);
 
     ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_cpy(lib, op->src[0]->type, op->type);
 
@@ -1424,7 +1551,7 @@ int ggml_metal_op_pool_2d(ggml_metal_op_t ctx, int idx) {
     GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
     GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
     GGML_TENSOR_LOCALS( int32_t, ne,  op,         ne);
-    GGML_TENSOR_LOCALS(uint32_t, nb,  op,         nb);
+    GGML_TENSOR_LOCALS(uint64_t, nb,  op,         nb);
 
     const int32_t * opts = op->op_params;
     ggml_op_pool op_pool = (ggml_op_pool) opts[0];
@@ -1488,7 +1615,7 @@ int ggml_metal_op_mul_mat(ggml_metal_op_t ctx, int idx) {
     GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne);
     GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb);
     GGML_TENSOR_LOCALS( int32_t, ne,  op,         ne);
-    GGML_TENSOR_LOCALS(uint32_t, nb,  op,         nb);
+    GGML_TENSOR_LOCALS(uint64_t, nb,  op,         nb);
 
     GGML_ASSERT(ne00 == ne10);
 
@@ -1729,7 +1856,7 @@ int ggml_metal_op_mul_mat_id(ggml_metal_op_t ctx, int idx) {
     GGML_TENSOR_LOCALS( int32_t, ne2, op->src[2], ne);
     GGML_TENSOR_LOCALS(uint64_t, nb2, op->src[2], nb);
     GGML_TENSOR_LOCALS( int32_t, ne,  op,         ne);
-    GGML_TENSOR_LOCALS(uint32_t, nb,  op,         nb);
+    GGML_TENSOR_LOCALS(uint64_t, nb,  op,         nb);
 
     // src2 = ids
     GGML_ASSERT(op->src[2]->type == GGML_TYPE_I32);
@@ -2689,7 +2816,7 @@ int ggml_metal_op_l2_norm(ggml_metal_op_t ctx, int idx) {
     GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
     GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
     GGML_TENSOR_LOCALS( int32_t, ne,  op,         ne);
-    GGML_TENSOR_LOCALS(uint32_t, nb,  op,         nb);
+    GGML_TENSOR_LOCALS(uint64_t, nb,  op,         nb);
 
     float eps;
     memcpy(&eps, op->op_params, sizeof(float));
@@ -2737,7 +2864,7 @@ int ggml_metal_op_group_norm(ggml_metal_op_t ctx, int idx) {
     GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
     GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
     GGML_TENSOR_LOCALS( int32_t, ne,  op,         ne);
-    GGML_TENSOR_LOCALS(uint32_t, nb,  op,         nb);
+    GGML_TENSOR_LOCALS(uint64_t, nb,  op,         nb);
 
     const int32_t ngrp = ((const int32_t *) op->op_params)[0];
 
@@ -2792,7 +2919,7 @@ int ggml_metal_op_norm(ggml_metal_op_t ctx, int idx) {
     GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
     GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
     GGML_TENSOR_LOCALS( int32_t, ne,  op,         ne);
-    GGML_TENSOR_LOCALS(uint32_t, nb,  op,         nb);
+    GGML_TENSOR_LOCALS(uint64_t, nb,  op,         nb);
 
     float eps;
     memcpy(&eps, op->op_params, sizeof(float));
@@ -2928,7 +3055,7 @@ int ggml_metal_op_rope(ggml_metal_op_t ctx, int idx) {
     GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne);
     GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb);
     GGML_TENSOR_LOCALS( int32_t, ne,  op,         ne);
-    GGML_TENSOR_LOCALS(uint32_t, nb,  op,         nb);
+    GGML_TENSOR_LOCALS(uint64_t, nb,  op,         nb);
 
     // make sure we have one or more position id(ne10) per token(ne02)
     GGML_ASSERT(ne10 % ne02 == 0);
@@ -3022,7 +3149,7 @@ int ggml_metal_op_im2col(ggml_metal_op_t ctx, int idx) {
     GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
     GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
     GGML_TENSOR_LOCALS( int32_t, ne,  op,         ne);
-    GGML_TENSOR_LOCALS(uint32_t, nb,  op,         nb);
+    GGML_TENSOR_LOCALS(uint64_t, nb,  op,         nb);
 
     const int32_t s0 = ((const int32_t *)(op->op_params))[0];
     const int32_t s1 = ((const int32_t *)(op->op_params))[1];
@@ -3172,7 +3299,7 @@ int ggml_metal_op_conv_transpose_1d(ggml_metal_op_t ctx, int idx) {
     GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne);
     GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb);
     GGML_TENSOR_LOCALS( int32_t, ne,  op,         ne);
-    GGML_TENSOR_LOCALS(uint32_t, nb,  op,         nb);
+    GGML_TENSOR_LOCALS(uint64_t, nb,  op,         nb);
 
     const int32_t s0 = ((const int32_t *)(op->op_params))[0];
 
@@ -3217,7 +3344,7 @@ int ggml_metal_op_conv_transpose_2d(ggml_metal_op_t ctx, int idx) {
     GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne);
     GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb);
     GGML_TENSOR_LOCALS( int32_t, ne,  op,         ne);
-    GGML_TENSOR_LOCALS(uint32_t, nb,  op,         nb);
+    GGML_TENSOR_LOCALS(uint64_t, nb,  op,         nb);
 
     const int32_t s0 = ((const int32_t *)(op->op_params))[0];
 
@@ -3271,7 +3398,7 @@ int ggml_metal_op_upscale(ggml_metal_op_t ctx, int idx) {
     GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
     GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
     GGML_TENSOR_LOCALS( int32_t, ne,  op,         ne);
-    GGML_TENSOR_LOCALS(uint32_t, nb,  op,         nb);
+    GGML_TENSOR_LOCALS(uint64_t, nb,  op,         nb);
 
     const float sf0 = (float)ne0/op->src[0]->ne[0];
     const float sf1 = (float)ne1/op->src[0]->ne[1];
@@ -3324,7 +3451,7 @@ int ggml_metal_op_pad(ggml_metal_op_t ctx, int idx) {
     GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
     GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
     GGML_TENSOR_LOCALS( int32_t, ne,  op,         ne);
-    GGML_TENSOR_LOCALS(uint32_t, nb,  op,         nb);
+    GGML_TENSOR_LOCALS(uint64_t, nb,  op,         nb);
 
     ggml_metal_kargs_pad args = {
         /*.ne00 =*/ ne00,
@@ -3368,7 +3495,7 @@ int ggml_metal_op_pad_reflect_1d(ggml_metal_op_t ctx, int idx) {
     GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
     GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
     GGML_TENSOR_LOCALS( int32_t, ne,  op,         ne);
-    GGML_TENSOR_LOCALS(uint32_t, nb,  op,         nb);
+    GGML_TENSOR_LOCALS(uint64_t, nb,  op,         nb);
 
     ggml_metal_kargs_pad_reflect_1d args = {
         /*.ne00 =*/ ne00,
@@ -3412,7 +3539,7 @@ int ggml_metal_op_arange(ggml_metal_op_t ctx, int idx) {
     ggml_metal_encoder_t enc = ctx->enc;
 
     GGML_TENSOR_LOCALS( int32_t, ne,  op,         ne);
-    GGML_TENSOR_LOCALS(uint32_t, nb,  op,         nb);
+    GGML_TENSOR_LOCALS(uint64_t, nb,  op,         nb);
 
     float start;
     float step;
@@ -3430,12 +3557,6 @@ int ggml_metal_op_arange(ggml_metal_op_t ctx, int idx) {
 
     ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_arange(lib, op);
 
-    //[encoder setComputePipelineState:pipeline];
-    //[encoder setBuffer:id_dst  offset:offs_dst  atIndex:0];
-    //[encoder setBytes:&args length:sizeof(args) atIndex:1];
-
-    //[encoder dispatchThreadgroups:MTLSizeMake(1, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
-
     ggml_metal_encoder_set_pipeline(enc, pipeline);
     ggml_metal_encoder_set_bytes   (enc, &args, sizeof(args), 0);
     ggml_metal_encoder_set_buffer  (enc, ggml_metal_get_buffer_id(op), 1);
@@ -3454,7 +3575,7 @@ int ggml_metal_op_timestep_embedding(ggml_metal_op_t ctx, int idx) {
     GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
     GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
     GGML_TENSOR_LOCALS( int32_t, ne,  op,         ne);
-    GGML_TENSOR_LOCALS(uint32_t, nb,  op,         nb);
+    GGML_TENSOR_LOCALS(uint64_t, nb,  op,         nb);
 
     const int dim        = op->op_params[0];
     const int max_period = op->op_params[1];
@@ -3488,7 +3609,7 @@ int ggml_metal_op_argmax(ggml_metal_op_t ctx, int idx) {
     GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
     GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
     GGML_TENSOR_LOCALS( int32_t, ne,  op,         ne);
-    GGML_TENSOR_LOCALS(uint32_t, nb,  op,         nb);
+    GGML_TENSOR_LOCALS(uint64_t, nb,  op,         nb);
 
     ggml_metal_kargs_argmax args = {
         /*.ne00 = */ ne00,
@@ -3529,7 +3650,7 @@ int ggml_metal_op_argsort(ggml_metal_op_t ctx, int idx) {
     GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
     GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
     GGML_TENSOR_LOCALS( int32_t, ne,  op,         ne);
-    GGML_TENSOR_LOCALS(uint32_t, nb,  op,         nb);
+    GGML_TENSOR_LOCALS(uint64_t, nb,  op,         nb);
 
     ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_argsort(lib, op);
 
@@ -3539,7 +3660,7 @@ int ggml_metal_op_argsort(ggml_metal_op_t ctx, int idx) {
         nth *= 2;
     }
 
-    const int nptg = (ne00 + nth - 1)/nth;
+    const int npr = (ne00 + nth - 1)/nth;
 
     // Metal kernels require the buffer size to be multiple of 16 bytes
     // https://developer.apple.com/documentation/metal/mtlcomputecommandencoder/1443142-setthreadgroupmemorylength
@@ -3551,7 +3672,7 @@ int ggml_metal_op_argsort(ggml_metal_op_t ctx, int idx) {
     ggml_metal_buffer_id bid_tmp = bid_dst;
     bid_tmp.offs += ggml_nbytes(op);
 
-    if ((int) ceil(std::log(nptg) / std::log(2)) % 2 == 1) {
+    if ((int) ceil(std::log(npr) / std::log(2)) % 2 == 1) {
         std::swap(bid_dst, bid_tmp);
     }
 
@@ -3573,7 +3694,7 @@ int ggml_metal_op_argsort(ggml_metal_op_t ctx, int idx) {
 
     ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0);
 
-    ggml_metal_encoder_dispatch_threadgroups(enc, nptg*ne01, ne02, ne03, nth, 1, 1);
+    ggml_metal_encoder_dispatch_threadgroups(enc, npr*ne01, ne02, ne03, nth, 1, 1);
 
     ggml_metal_pipeline_t pipeline_merge = ggml_metal_library_get_pipeline_argsort_merge(lib, op);
 
@@ -3626,7 +3747,7 @@ int ggml_metal_op_leaky_relu(ggml_metal_op_t ctx, int idx) {
     GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
     GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
     GGML_TENSOR_LOCALS( int32_t, ne,  op,         ne);
-    GGML_TENSOR_LOCALS(uint32_t, nb,  op,         nb);
+    GGML_TENSOR_LOCALS(uint64_t, nb,  op,         nb);
 
     float slope;
     memcpy(&slope, op->op_params, sizeof(float));
@@ -3662,7 +3783,7 @@ int ggml_metal_op_opt_step_adamw(ggml_metal_op_t ctx, int idx) {
     GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
     GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
     GGML_TENSOR_LOCALS( int32_t, ne,  op,         ne);
-    GGML_TENSOR_LOCALS(uint32_t, nb,  op,         nb);
+    GGML_TENSOR_LOCALS(uint64_t, nb,  op,         nb);
 
     ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_opt_step_adamw(lib, op);
 
@@ -3698,7 +3819,7 @@ int ggml_metal_op_opt_step_sgd(ggml_metal_op_t ctx, int idx) {
     GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
     GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
     GGML_TENSOR_LOCALS( int32_t, ne,  op,         ne);
-    GGML_TENSOR_LOCALS(uint32_t, nb,  op,         nb);
+    GGML_TENSOR_LOCALS(uint64_t, nb,  op,         nb);
 
     ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_opt_step_sgd(lib, op);
 

+ 1 - 0
ggml/src/ggml-metal/ggml-metal-ops.h

@@ -52,6 +52,7 @@ int ggml_metal_op_unary             (ggml_metal_op_t ctx, int idx);
 int ggml_metal_op_glu               (ggml_metal_op_t ctx, int idx);
 int ggml_metal_op_sum               (ggml_metal_op_t ctx, int idx);
 int ggml_metal_op_sum_rows          (ggml_metal_op_t ctx, int idx);
+int ggml_metal_op_cumsum            (ggml_metal_op_t ctx, int idx);
 int ggml_metal_op_get_rows          (ggml_metal_op_t ctx, int idx);
 int ggml_metal_op_set_rows          (ggml_metal_op_t ctx, int idx);
 int ggml_metal_op_soft_max          (ggml_metal_op_t ctx, int idx);

+ 1 - 0
ggml/src/ggml-metal/ggml-metal.cpp

@@ -197,6 +197,7 @@ static size_t ggml_backend_metal_buffer_type_get_alloc_size(ggml_backend_buffer_
                 res += ggml_metal_op_flash_attn_ext_extra_blk(tensor);
                 res += ggml_metal_op_flash_attn_ext_extra_tmp(tensor);
             } break;
+        case GGML_OP_CUMSUM:
         case GGML_OP_ARGSORT:
             {
                 res *= 2;

+ 126 - 15
ggml/src/ggml-metal/ggml-metal.metal

@@ -1832,6 +1832,117 @@ typedef decltype(kernel_sum_rows<false>) kernel_sum_rows_t;
 template [[host_name("kernel_sum_rows_f32")]] kernel kernel_sum_rows_t kernel_sum_rows<false>;
 template [[host_name("kernel_mean_f32")]]     kernel kernel_sum_rows_t kernel_sum_rows<true>;
 
+template<typename T>
+kernel void kernel_cumsum_blk(
+        constant ggml_metal_kargs_cumsum_blk & args,
+        device const char * src0,
+        device       char * tmp,
+        device       char * dst,
+        threadgroup  char * shmem [[threadgroup(0)]],
+        uint3   tgpig[[threadgroup_position_in_grid]],
+        ushort3 tpitg[[thread_position_in_threadgroup]],
+        ushort  sgitg[[simdgroup_index_in_threadgroup]],
+        ushort  tiisg[[thread_index_in_simdgroup]],
+        ushort3   ntg[[threads_per_threadgroup]]) {
+    const int ib = tgpig[0]/args.ne01;
+
+    const int i00 = ib*ntg.x;
+    const int i01 = tgpig[0]%args.ne01;
+    const int i02 = tgpig[1];
+    const int i03 = tgpig[2];
+
+    device const float * src0_row = (device const float *) (src0 +
+            args.nb01*i01 +
+            args.nb02*i02 +
+            args.nb03*i03);
+
+    threadgroup float * shmem_f32 = (threadgroup float *) shmem;
+
+    float v = 0.0f;
+
+    if (i00 + tpitg.x < args.ne00) {
+        v = src0_row[i00 + tpitg.x];
+    }
+
+    float s = simd_prefix_inclusive_sum(v);
+
+    if (tiisg == N_SIMDWIDTH - 1) {
+        shmem_f32[sgitg] = s;
+    }
+
+    threadgroup_barrier(mem_flags::mem_threadgroup);
+
+    if (sgitg == 0) {
+        shmem_f32[tiisg] = simd_prefix_exclusive_sum(shmem_f32[tiisg]);
+    }
+
+    threadgroup_barrier(mem_flags::mem_threadgroup);
+
+    s += shmem_f32[sgitg];
+
+    device float * dst_row = (device float *) dst +
+        args.ne00*i01 +
+        args.ne00*args.ne01*i02 +
+        args.ne00*args.ne01*args.ne02*i03;
+
+    if (i00 + tpitg.x < args.ne00) {
+        dst_row[i00 + tpitg.x] = s;
+    }
+
+    if (args.outb && tpitg.x == ntg.x - 1) {
+        device float * tmp_row = (device float *) tmp +
+            args.net0*i01 +
+            args.net0*args.net1*i02 +
+            args.net0*args.net1*args.net2*i03;
+
+        tmp_row[ib] = s;
+    }
+}
+
+typedef decltype(kernel_cumsum_blk<float>) kernel_cumsum_blk_t;
+
+template [[host_name("kernel_cumsum_blk_f32")]] kernel kernel_cumsum_blk_t kernel_cumsum_blk<float>;
+
+template<typename T>
+kernel void kernel_cumsum_add(
+        constant ggml_metal_kargs_cumsum_add & args,
+        device const char * tmp,
+        device       char * dst,
+        uint3   tgpig[[threadgroup_position_in_grid]],
+        ushort3 tpitg[[thread_position_in_threadgroup]],
+        ushort  sgitg[[simdgroup_index_in_threadgroup]],
+        ushort  tiisg[[thread_index_in_simdgroup]],
+        ushort3   ntg[[threads_per_threadgroup]]) {
+    const int ib = tgpig[0]/args.ne01;
+
+    if (ib == 0) {
+        return;
+    }
+
+    const int i00 = ib*ntg.x;
+    const int i01 = tgpig[0]%args.ne01;
+    const int i02 = tgpig[1];
+    const int i03 = tgpig[2];
+
+    device const float * tmp_row = (device const float *) (tmp +
+            args.nbt1*i01 +
+            args.nbt2*i02 +
+            args.nbt3*i03);
+
+    device float * dst_row = (device float *) dst +
+        args.ne00*i01 +
+        args.ne00*args.ne01*i02 +
+        args.ne00*args.ne01*args.ne02*i03;
+
+    if (i00 + tpitg.x < args.ne00) {
+        dst_row[i00 + tpitg.x] += tmp_row[ib - 1];
+    }
+}
+
+typedef decltype(kernel_cumsum_add<float>) kernel_cumsum_add_t;
+
+template [[host_name("kernel_cumsum_add_f32")]] kernel kernel_cumsum_add_t kernel_cumsum_add<float>;
+
 template<typename T>
 kernel void kernel_soft_max(
         constant ggml_metal_kargs_soft_max & args,
@@ -4543,7 +4654,7 @@ typedef void (argsort_t)(
         constant   ggml_metal_kargs_argsort & args,
         device   const char * src0,
         device      int32_t * dst,
-        threadgroup int32_t * smem_i32 [[threadgroup(0)]],
+        threadgroup int32_t * shmem_i32 [[threadgroup(0)]],
         uint3   tgpig[[threadgroup_position_in_grid]],
         ushort3 tpitg[[thread_position_in_threadgroup]],
         ushort3   ntg[[threads_per_threadgroup]]);
@@ -4553,7 +4664,7 @@ kernel void kernel_argsort_f32_i32(
         constant   ggml_metal_kargs_argsort & args,
         device   const char * src0,
         device      int32_t * dst,
-        threadgroup int32_t * smem_i32 [[threadgroup(0)]],
+        threadgroup int32_t * shmem_i32 [[threadgroup(0)]],
         uint3   tgpig[[threadgroup_position_in_grid]],
         ushort3 tpitg[[thread_position_in_threadgroup]],
         ushort3   ntg[[threads_per_threadgroup]]) {
@@ -4565,10 +4676,10 @@ kernel void kernel_argsort_f32_i32(
     const int i02 =  tgpig[1];
     const int i03 =  tgpig[2];
 
-    device const float * x_row = (device const float *) (src0 + args.nb01*i01 + args.nb02*i02 + args.nb03*i03);
+    device const float * src0_row = (device const float *) (src0 + args.nb01*i01 + args.nb02*i02 + args.nb03*i03);
 
     // initialize indices
-    smem_i32[col] = i00 + col;
+    shmem_i32[col] = i00 + col;
 
     threadgroup_barrier(mem_flags::mem_threadgroup);
 
@@ -4577,20 +4688,20 @@ kernel void kernel_argsort_f32_i32(
             int ixj = col ^ j;
             if (ixj > col) {
                 if ((col & k) == 0) {
-                    if (smem_i32[col] >= args.ne00 ||
-                       (smem_i32[ixj] <  args.ne00 && (order == GGML_SORT_ORDER_ASC ?
-                            x_row[smem_i32[col]] > x_row[smem_i32[ixj]] :
-                            x_row[smem_i32[col]] < x_row[smem_i32[ixj]]))
+                    if (shmem_i32[col] >= args.ne00 ||
+                       (shmem_i32[ixj] <  args.ne00 && (order == GGML_SORT_ORDER_ASC ?
+                            src0_row[shmem_i32[col]] > src0_row[shmem_i32[ixj]] :
+                            src0_row[shmem_i32[col]] < src0_row[shmem_i32[ixj]]))
                     ) {
-                        SWAP(smem_i32[col], smem_i32[ixj]);
+                        SWAP(shmem_i32[col], shmem_i32[ixj]);
                     }
                 } else {
-                    if (smem_i32[ixj] >= args.ne00 ||
-                       (smem_i32[col] <  args.ne00 && (order == GGML_SORT_ORDER_ASC ?
-                            x_row[smem_i32[col]] < x_row[smem_i32[ixj]] :
-                            x_row[smem_i32[col]] > x_row[smem_i32[ixj]]))
+                    if (shmem_i32[ixj] >= args.ne00 ||
+                       (shmem_i32[col] <  args.ne00 && (order == GGML_SORT_ORDER_ASC ?
+                            src0_row[shmem_i32[col]] < src0_row[shmem_i32[ixj]] :
+                            src0_row[shmem_i32[col]] > src0_row[shmem_i32[ixj]]))
                     ) {
-                        SWAP(smem_i32[col], smem_i32[ixj]);
+                        SWAP(shmem_i32[col], shmem_i32[ixj]);
                     }
                 }
             }
@@ -4603,7 +4714,7 @@ kernel void kernel_argsort_f32_i32(
     if (i00 + col < args.ne00) {
         dst += i00 + args.ne00*i01 + args.ne00*args.ne01*i02 + args.ne00*args.ne01*args.ne02*i03;
 
-        dst[col] = smem_i32[col];
+        dst[col] = shmem_i32[col];
     }
 }
 

+ 14 - 1
tests/test-backend-ops.cpp

@@ -7558,7 +7558,20 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
     test_cases.emplace_back(new test_arange());
     test_cases.emplace_back(new test_timestep_embedding());
     test_cases.emplace_back(new test_leaky_relu());
-    test_cases.emplace_back(new test_cumsum());
+
+    test_cases.emplace_back(new test_cumsum(GGML_TYPE_F32, { 10, 5, 4, 3 }));
+    test_cases.emplace_back(new test_cumsum(GGML_TYPE_F32, { 127, 5, 4, 3 }));
+    test_cases.emplace_back(new test_cumsum(GGML_TYPE_F32, { 128, 5, 4, 3 }));
+    test_cases.emplace_back(new test_cumsum(GGML_TYPE_F32, { 255, 5, 4, 3 }));
+    test_cases.emplace_back(new test_cumsum(GGML_TYPE_F32, { 256, 5, 4, 3 }));
+    test_cases.emplace_back(new test_cumsum(GGML_TYPE_F32, { 511, 5, 4, 3 }));
+    test_cases.emplace_back(new test_cumsum(GGML_TYPE_F32, { 512, 5, 4, 3 }));
+    test_cases.emplace_back(new test_cumsum(GGML_TYPE_F32, { 1023, 5, 4, 3 }));
+    test_cases.emplace_back(new test_cumsum(GGML_TYPE_F32, { 1024, 5, 4, 3 }));
+    test_cases.emplace_back(new test_cumsum(GGML_TYPE_F32, { 2047, 5, 4, 3 }));
+    test_cases.emplace_back(new test_cumsum(GGML_TYPE_F32, { 2048, 5, 4, 3 }));
+    test_cases.emplace_back(new test_cumsum(GGML_TYPE_F32, { 201*1204, 1, 1, 1 }));
+    test_cases.emplace_back(new test_cumsum(GGML_TYPE_F32, { 312*1205, 1, 1, 1 }));
 
     test_cases.emplace_back(new test_xielu());