Просмотр исходного кода

metal : support argsort for ne00 > 1024 (#17247)

* metal : refactor argsort

* cont : sort chunks

* cont : merge sorted buckets

* cont : cleanup
Georgi Gerganov 2 месяцев назад
Родитель
Сommit
45c6ef7307

+ 28 - 0
ggml/src/ggml-metal/ggml-metal-device.cpp

@@ -943,6 +943,34 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_argsort(ggml_metal_library
     return res;
     return res;
 }
 }
 
 
+ggml_metal_pipeline_t ggml_metal_library_get_pipeline_argsort_merge(ggml_metal_library_t lib, const ggml_tensor * op) {
+    assert(op->op == GGML_OP_ARGSORT);
+
+    char base[256];
+    char name[256];
+
+    ggml_sort_order order = (ggml_sort_order) op->op_params[0];
+
+    const char * order_str = "undefined";
+    switch (order) {
+        case GGML_SORT_ORDER_ASC:  order_str = "asc";  break;
+        case GGML_SORT_ORDER_DESC: order_str = "desc"; break;
+        default: GGML_ABORT("fatal error");
+    };
+
+    snprintf(base, 256, "kernel_argsort_merge_%s_%s_%s", ggml_type_name(op->src[0]->type), ggml_type_name(op->type), order_str);
+    snprintf(name, 256, "%s", base);
+
+    ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name);
+    if (res) {
+        return res;
+    }
+
+    res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
+
+    return res;
+}
+
 ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext_pad(
 ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext_pad(
         ggml_metal_library_t lib,
         ggml_metal_library_t lib,
         const struct ggml_tensor * op,
         const struct ggml_tensor * op,

+ 1 - 0
ggml/src/ggml-metal/ggml-metal-device.h

@@ -125,6 +125,7 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_mul_mm_id         (ggml_me
 ggml_metal_pipeline_t ggml_metal_library_get_pipeline_mul_mv_id         (ggml_metal_library_t lib, const struct ggml_tensor * op);
 ggml_metal_pipeline_t ggml_metal_library_get_pipeline_mul_mv_id         (ggml_metal_library_t lib, const struct ggml_tensor * op);
 ggml_metal_pipeline_t ggml_metal_library_get_pipeline_argmax            (ggml_metal_library_t lib, const struct ggml_tensor * op);
 ggml_metal_pipeline_t ggml_metal_library_get_pipeline_argmax            (ggml_metal_library_t lib, const struct ggml_tensor * op);
 ggml_metal_pipeline_t ggml_metal_library_get_pipeline_argsort           (ggml_metal_library_t lib, const struct ggml_tensor * op);
 ggml_metal_pipeline_t ggml_metal_library_get_pipeline_argsort           (ggml_metal_library_t lib, const struct ggml_tensor * op);
+ggml_metal_pipeline_t ggml_metal_library_get_pipeline_argsort_merge     (ggml_metal_library_t lib, const struct ggml_tensor * op);
 ggml_metal_pipeline_t ggml_metal_library_get_pipeline_bin               (ggml_metal_library_t lib, enum ggml_op op, int32_t n_fuse, bool row);
 ggml_metal_pipeline_t ggml_metal_library_get_pipeline_bin               (ggml_metal_library_t lib, enum ggml_op op, int32_t n_fuse, bool row);
 ggml_metal_pipeline_t ggml_metal_library_get_pipeline_l2_norm           (ggml_metal_library_t lib, const struct ggml_tensor * op);
 ggml_metal_pipeline_t ggml_metal_library_get_pipeline_l2_norm           (ggml_metal_library_t lib, const struct ggml_tensor * op);
 ggml_metal_pipeline_t ggml_metal_library_get_pipeline_group_norm        (ggml_metal_library_t lib, const struct ggml_tensor * op);
 ggml_metal_pipeline_t ggml_metal_library_get_pipeline_group_norm        (ggml_metal_library_t lib, const struct ggml_tensor * op);

+ 0 - 2
ggml/src/ggml-metal/ggml-metal-device.m

@@ -904,8 +904,6 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te
         case GGML_OP_LEAKY_RELU:
         case GGML_OP_LEAKY_RELU:
             return op->src[0]->type == GGML_TYPE_F32;
             return op->src[0]->type == GGML_TYPE_F32;
         case GGML_OP_ARGSORT:
         case GGML_OP_ARGSORT:
-            // TODO: Support arbitrary column width
-            return op->src[0]->ne[0] <= 1024;
         case GGML_OP_ARANGE:
         case GGML_OP_ARANGE:
             return true;
             return true;
         case GGML_OP_FLASH_ATTN_EXT:
         case GGML_OP_FLASH_ATTN_EXT:

+ 20 - 2
ggml/src/ggml-metal/ggml-metal-impl.h

@@ -793,10 +793,28 @@ typedef struct {
 } ggml_metal_kargs_leaky_relu;
 } ggml_metal_kargs_leaky_relu;
 
 
 typedef struct {
 typedef struct {
-    int64_t  ncols;
-    int64_t  ncols_pad;
+    int64_t  ne00;
+    int64_t  ne01;
+    int64_t  ne02;
+    int64_t  ne03;
+    uint64_t nb00;
+    uint64_t nb01;
+    uint64_t nb02;
+    uint64_t nb03;
 } ggml_metal_kargs_argsort;
 } ggml_metal_kargs_argsort;
 
 
+typedef struct {
+    int64_t  ne00;
+    int64_t  ne01;
+    int64_t  ne02;
+    int64_t  ne03;
+    uint64_t nb00;
+    uint64_t nb01;
+    uint64_t nb02;
+    uint64_t nb03;
+    int32_t  len;
+} ggml_metal_kargs_argsort_merge;
+
 typedef struct {
 typedef struct {
     int64_t  ne0;
     int64_t  ne0;
     float    start;
     float    start;

+ 69 - 12
ggml/src/ggml-metal/ggml-metal-ops.cpp

@@ -3530,38 +3530,95 @@ int ggml_metal_op_argsort(ggml_metal_op_t ctx, int idx) {
     ggml_metal_library_t lib = ctx->lib;
     ggml_metal_library_t lib = ctx->lib;
     ggml_metal_encoder_t enc = ctx->enc;
     ggml_metal_encoder_t enc = ctx->enc;
 
 
+    GGML_ASSERT(ggml_is_contiguous_rows(op->src[0]));
+
     GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
     GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
     GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
     GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
     GGML_TENSOR_LOCALS( int32_t, ne,  op,         ne);
     GGML_TENSOR_LOCALS( int32_t, ne,  op,         ne);
     GGML_TENSOR_LOCALS(uint32_t, nb,  op,         nb);
     GGML_TENSOR_LOCALS(uint32_t, nb,  op,         nb);
 
 
+    ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_argsort(lib, op);
+
     // bitonic sort requires the number of elements to be power of 2
     // bitonic sort requires the number of elements to be power of 2
-    int64_t ne00_padded = 1;
-    while (ne00_padded < ne00) {
-        ne00_padded *= 2;
+    int nth = 1;
+    while (nth < ne00 && 2*nth <= ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)) {
+        nth *= 2;
     }
     }
 
 
-    ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_argsort(lib, op);
-
-    const int64_t nrows = ggml_nrows(op->src[0]);
+    const int nptg = (ne00 + nth - 1)/nth;
 
 
     // Metal kernels require the buffer size to be multiple of 16 bytes
     // Metal kernels require the buffer size to be multiple of 16 bytes
     // https://developer.apple.com/documentation/metal/mtlcomputecommandencoder/1443142-setthreadgroupmemorylength
     // https://developer.apple.com/documentation/metal/mtlcomputecommandencoder/1443142-setthreadgroupmemorylength
-    const size_t smem = GGML_PAD(ne00_padded*sizeof(int32_t), 16);
+    const size_t smem = GGML_PAD(nth*sizeof(int32_t), 16);
+
+    ggml_metal_buffer_id bid_src0 = ggml_metal_get_buffer_id(op->src[0]);
+    ggml_metal_buffer_id bid_dst  = ggml_metal_get_buffer_id(op);
+
+    ggml_metal_buffer_id bid_tmp = bid_dst;
+    bid_tmp.offs += ggml_nbytes(op);
+
+    if ((int) ceil(std::log(nptg) / std::log(2)) % 2 == 1) {
+        std::swap(bid_dst, bid_tmp);
+    }
 
 
     ggml_metal_kargs_argsort args = {
     ggml_metal_kargs_argsort args = {
-        /*.ncols =*/ ne00,
-        /*.ncols_pad =*/ ne00_padded
+        /*.ne00 =*/ ne00,
+        /*.ne01 =*/ ne01,
+        /*.ne02 =*/ ne02,
+        /*.ne03 =*/ ne03,
+        /*.nb00 =*/ nb00,
+        /*.nb01 =*/ nb01,
+        /*.nb02 =*/ nb02,
+        /*.nb03 =*/ nb03,
     };
     };
 
 
     ggml_metal_encoder_set_pipeline(enc, pipeline);
     ggml_metal_encoder_set_pipeline(enc, pipeline);
     ggml_metal_encoder_set_bytes   (enc, &args, sizeof(args), 0);
     ggml_metal_encoder_set_bytes   (enc, &args, sizeof(args), 0);
-    ggml_metal_encoder_set_buffer  (enc, ggml_metal_get_buffer_id(op->src[0]), 1);
-    ggml_metal_encoder_set_buffer  (enc, ggml_metal_get_buffer_id(op),         2);
+    ggml_metal_encoder_set_buffer  (enc, bid_src0, 1);
+    ggml_metal_encoder_set_buffer  (enc, bid_dst,  2);
 
 
     ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0);
     ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0);
 
 
-    ggml_metal_encoder_dispatch_threadgroups(enc, 1, nrows, 1, ne00_padded, 1, 1);
+    ggml_metal_encoder_dispatch_threadgroups(enc, nptg*ne01, ne02, ne03, nth, 1, 1);
+
+    ggml_metal_pipeline_t pipeline_merge = ggml_metal_library_get_pipeline_argsort_merge(lib, op);
+
+    int len = nth;
+
+    while (len < ne00) {
+        ggml_metal_op_concurrency_reset(ctx);
+
+        ggml_metal_kargs_argsort_merge args_merge = {
+            .ne00 = ne00,
+            .ne01 = ne01,
+            .ne02 = ne02,
+            .ne03 = ne03,
+            .nb00 = nb00,
+            .nb01 = nb01,
+            .nb02 = nb02,
+            .nb03 = nb03,
+            .len  = len,
+        };
+
+        // merges per row
+        const int nm = (ne00 + 2*len - 1) / (2*len);
+
+        const int nth = std::min(512, ggml_metal_pipeline_max_theads_per_threadgroup(pipeline_merge));
+
+        ggml_metal_encoder_set_pipeline(enc, pipeline_merge);
+        ggml_metal_encoder_set_bytes   (enc, &args_merge, sizeof(args_merge), 0);
+        ggml_metal_encoder_set_buffer  (enc, bid_src0, 1);
+        ggml_metal_encoder_set_buffer  (enc, bid_dst,  2);
+        ggml_metal_encoder_set_buffer  (enc, bid_tmp,  3);
+
+        ggml_metal_encoder_set_threadgroup_memory_size(enc, 0, 0);
+
+        ggml_metal_encoder_dispatch_threadgroups(enc, nm*ne01, ne02, ne03, nth, 1, 1);
+
+        std::swap(bid_dst, bid_tmp);
+
+        len <<= 1;
+    }
 
 
     return 1;
     return 1;
 }
 }

+ 4 - 0
ggml/src/ggml-metal/ggml-metal.cpp

@@ -197,6 +197,10 @@ static size_t ggml_backend_metal_buffer_type_get_alloc_size(ggml_backend_buffer_
                 res += ggml_metal_op_flash_attn_ext_extra_blk(tensor);
                 res += ggml_metal_op_flash_attn_ext_extra_blk(tensor);
                 res += ggml_metal_op_flash_attn_ext_extra_tmp(tensor);
                 res += ggml_metal_op_flash_attn_ext_extra_tmp(tensor);
             } break;
             } break;
+        case GGML_OP_ARGSORT:
+            {
+                res *= 2;
+            } break;
         default:
         default:
             break;
             break;
     }
     }

+ 137 - 27
ggml/src/ggml-metal/ggml-metal.metal

@@ -4541,69 +4541,179 @@ kernel void kernel_timestep_embedding_f32(
 // bitonic sort implementation following the CUDA kernels as reference
 // bitonic sort implementation following the CUDA kernels as reference
 typedef void (argsort_t)(
 typedef void (argsort_t)(
         constant   ggml_metal_kargs_argsort & args,
         constant   ggml_metal_kargs_argsort & args,
-        device  const float * x,
+        device   const char * src0,
         device      int32_t * dst,
         device      int32_t * dst,
-        threadgroup int32_t * shared_values [[threadgroup(0)]],
-        uint3 tgpig[[threadgroup_position_in_grid]],
-        uint3 tpitg[[thread_position_in_threadgroup]]);
+        threadgroup int32_t * smem_i32 [[threadgroup(0)]],
+        uint3   tgpig[[threadgroup_position_in_grid]],
+        ushort3 tpitg[[thread_position_in_threadgroup]],
+        ushort3   ntg[[threads_per_threadgroup]]);
 
 
 template<ggml_sort_order order>
 template<ggml_sort_order order>
 kernel void kernel_argsort_f32_i32(
 kernel void kernel_argsort_f32_i32(
         constant   ggml_metal_kargs_argsort & args,
         constant   ggml_metal_kargs_argsort & args,
-        device const float  * x,
+        device   const char * src0,
         device      int32_t * dst,
         device      int32_t * dst,
-        threadgroup int32_t * shared_values [[threadgroup(0)]],
-        uint3 tgpig[[threadgroup_position_in_grid]],
-        uint3 tpitg[[thread_position_in_threadgroup]]) {
+        threadgroup int32_t * smem_i32 [[threadgroup(0)]],
+        uint3   tgpig[[threadgroup_position_in_grid]],
+        ushort3 tpitg[[thread_position_in_threadgroup]],
+        ushort3   ntg[[threads_per_threadgroup]]) {
     // bitonic sort
     // bitonic sort
-    int col = tpitg[0];
-    int row = tgpig[1];
+    const int col = tpitg[0];
 
 
-    if (col >= args.ncols_pad) return;
+    const int i00 = (tgpig[0]/args.ne01)*ntg.x;
+    const int i01 =  tgpig[0]%args.ne01;
+    const int i02 =  tgpig[1];
+    const int i03 =  tgpig[2];
 
 
-    device const float   * x_row   = x + row * args.ncols;
-    threadgroup int32_t  * dst_row = shared_values;
+    device const float * x_row = (device const float *) (src0 + args.nb01*i01 + args.nb02*i02 + args.nb03*i03);
 
 
     // initialize indices
     // initialize indices
-    dst_row[col] = col;
+    smem_i32[col] = i00 + col;
 
 
     threadgroup_barrier(mem_flags::mem_threadgroup);
     threadgroup_barrier(mem_flags::mem_threadgroup);
 
 
-    for (int k = 2; k <= args.ncols_pad; k *= 2) {
+    for (int k = 2; k <= ntg.x; k *= 2) {
         for (int j = k / 2; j > 0; j /= 2) {
         for (int j = k / 2; j > 0; j /= 2) {
             int ixj = col ^ j;
             int ixj = col ^ j;
             if (ixj > col) {
             if (ixj > col) {
                 if ((col & k) == 0) {
                 if ((col & k) == 0) {
-                    if (dst_row[col] >= args.ncols ||
-                        (dst_row[ixj] < args.ncols && (order == GGML_SORT_ORDER_ASC ?
-                            x_row[dst_row[col]] > x_row[dst_row[ixj]] :
-                            x_row[dst_row[col]] < x_row[dst_row[ixj]]))
+                    if (smem_i32[col] >= args.ne00 ||
+                       (smem_i32[ixj] <  args.ne00 && (order == GGML_SORT_ORDER_ASC ?
+                            x_row[smem_i32[col]] > x_row[smem_i32[ixj]] :
+                            x_row[smem_i32[col]] < x_row[smem_i32[ixj]]))
                     ) {
                     ) {
-                        SWAP(dst_row[col], dst_row[ixj]);
+                        SWAP(smem_i32[col], smem_i32[ixj]);
                     }
                     }
                 } else {
                 } else {
-                    if (dst_row[ixj] >= args.ncols ||
-                        (dst_row[col] < args.ncols && (order == GGML_SORT_ORDER_ASC ?
-                            x_row[dst_row[col]] < x_row[dst_row[ixj]] :
-                            x_row[dst_row[col]] > x_row[dst_row[ixj]]))
+                    if (smem_i32[ixj] >= args.ne00 ||
+                       (smem_i32[col] <  args.ne00 && (order == GGML_SORT_ORDER_ASC ?
+                            x_row[smem_i32[col]] < x_row[smem_i32[ixj]] :
+                            x_row[smem_i32[col]] > x_row[smem_i32[ixj]]))
                     ) {
                     ) {
-                        SWAP(dst_row[col], dst_row[ixj]);
+                        SWAP(smem_i32[col], smem_i32[ixj]);
                     }
                     }
                 }
                 }
             }
             }
+
             threadgroup_barrier(mem_flags::mem_threadgroup);
             threadgroup_barrier(mem_flags::mem_threadgroup);
         }
         }
     }
     }
 
 
     // copy the result to dst without the padding
     // copy the result to dst without the padding
-    if (col < args.ncols) {
-        dst[row * args.ncols + col] = dst_row[col];
+    if (i00 + col < args.ne00) {
+        dst += i00 + args.ne00*i01 + args.ne00*args.ne01*i02 + args.ne00*args.ne01*args.ne02*i03;
+
+        dst[col] = smem_i32[col];
     }
     }
 }
 }
 
 
 template [[host_name("kernel_argsort_f32_i32_asc")]]  kernel argsort_t kernel_argsort_f32_i32<GGML_SORT_ORDER_ASC>;
 template [[host_name("kernel_argsort_f32_i32_asc")]]  kernel argsort_t kernel_argsort_f32_i32<GGML_SORT_ORDER_ASC>;
 template [[host_name("kernel_argsort_f32_i32_desc")]] kernel argsort_t kernel_argsort_f32_i32<GGML_SORT_ORDER_DESC>;
 template [[host_name("kernel_argsort_f32_i32_desc")]] kernel argsort_t kernel_argsort_f32_i32<GGML_SORT_ORDER_DESC>;
 
 
+typedef void (argsort_merge_t)(
+        constant   ggml_metal_kargs_argsort_merge & args,
+        device const char    * src0,
+        device const int32_t * tmp,
+        device       int32_t * dst,
+        uint3   tgpig[[threadgroup_position_in_grid]],
+        ushort3 tpitg[[thread_position_in_threadgroup]],
+        ushort3   ntg[[threads_per_threadgroup]]);
+
+template<ggml_sort_order order>
+kernel void kernel_argsort_merge_f32_i32(
+        constant   ggml_metal_kargs_argsort_merge & args,
+        device const char    * src0,
+        device const int32_t * tmp,
+        device       int32_t * dst,
+        uint3   tgpig[[threadgroup_position_in_grid]],
+        ushort3 tpitg[[thread_position_in_threadgroup]],
+        ushort3   ntg[[threads_per_threadgroup]]) {
+    int im  = tgpig[0] / args.ne01;
+    int i01 = tgpig[0] % args.ne01;
+    int i02 = tgpig[1];
+    int i03 = tgpig[2];
+
+    const int start = im * (2*args.len);
+
+    const int len0 = MIN(args.len, MAX(0, args.ne00 - (int)(start)));
+    const int len1 = MIN(args.len, MAX(0, args.ne00 - (int)(start + args.len)));
+
+    const int total = len0 + len1;
+
+    device const int32_t * tmp0 = tmp + start
+        + i01*args.ne00
+        + i02*args.ne00*args.ne01
+        + i03*args.ne00*args.ne01*args.ne02;
+
+    device const int32_t * tmp1 = tmp0 + args.len;
+
+    dst += start
+        + i01*args.ne00
+        + i02*args.ne00*args.ne01
+        + i03*args.ne00*args.ne01*args.ne02;
+
+    device const float * src0_row = (device const float *)(src0
+        + args.nb01*i01
+        + args.nb02*i02
+        + args.nb03*i03);
+
+    for (int k = tpitg.x; k < (int) total; k += ntg.x) {
+        // find partition (i,j) such that i+j = k
+        int low  = k > len1 ? k - len1 : 0;
+        int high = MIN(k, len0);
+
+        while (low < high) {
+            const int mid = (low + high) >> 1;
+
+            const int32_t idx0 = tmp0[mid];
+            const int32_t idx1 = tmp1[k - mid - 1];
+
+            const float val0 = src0_row[idx0];
+            const float val1 = src0_row[idx1];
+
+            if (order == GGML_SORT_ORDER_ASC) {
+                if (val0 <= val1) {
+                    low = mid + 1;
+                } else {
+                    high = mid;
+                }
+            } else {
+                if (val0 >= val1) {
+                    low = mid + 1;
+                } else {
+                    high = mid;
+                }
+            }
+        }
+
+        const int i = low;
+        const int j = k - i;
+
+        int32_t out_idx;
+
+        if (i >= len0) {
+            out_idx = tmp1[j];
+        } else if (j >= len1) {
+            out_idx = tmp0[i];
+        } else {
+            const int32_t idx0 = tmp0[i];
+            const int32_t idx1 = tmp1[j];
+
+            const float val0 = src0_row[idx0];
+            const float val1 = src0_row[idx1];
+
+            out_idx = (order == GGML_SORT_ORDER_ASC)
+                ? (val0 <= val1 ? idx0 : idx1)
+                : (val0 >= val1 ? idx0 : idx1);
+        }
+
+        dst[k] = out_idx;
+    }
+}
+
+template [[host_name("kernel_argsort_merge_f32_i32_asc")]]  kernel argsort_merge_t kernel_argsort_merge_f32_i32<GGML_SORT_ORDER_ASC>;
+template [[host_name("kernel_argsort_merge_f32_i32_desc")]] kernel argsort_merge_t kernel_argsort_merge_f32_i32<GGML_SORT_ORDER_DESC>;
+
 kernel void kernel_leaky_relu_f32(
 kernel void kernel_leaky_relu_f32(
         constant     ggml_metal_kargs_leaky_relu & args,
         constant     ggml_metal_kargs_leaky_relu & args,
         device const float * src0,
         device const float * src0,

+ 6 - 1
tests/test-backend-ops.cpp

@@ -7492,8 +7492,13 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
         test_cases.emplace_back(new test_argsort(GGML_TYPE_F32, {8, 1, 1, 1}, order));
         test_cases.emplace_back(new test_argsort(GGML_TYPE_F32, {8, 1, 1, 1}, order));
         test_cases.emplace_back(new test_argsort(GGML_TYPE_F32, {16, 10, 10, 10}, order));
         test_cases.emplace_back(new test_argsort(GGML_TYPE_F32, {16, 10, 10, 10}, order));
         test_cases.emplace_back(new test_argsort(GGML_TYPE_F32, {60, 10, 10, 10}, order)); // qwen
         test_cases.emplace_back(new test_argsort(GGML_TYPE_F32, {60, 10, 10, 10}, order)); // qwen
-        test_cases.emplace_back(new test_argsort(GGML_TYPE_F32, {1024, 1, 1, 1}, order));
+        test_cases.emplace_back(new test_argsort(GGML_TYPE_F32, {1023, 2, 1, 3}, order));
+        test_cases.emplace_back(new test_argsort(GGML_TYPE_F32, {1024, 2, 1, 3}, order));
+        test_cases.emplace_back(new test_argsort(GGML_TYPE_F32, {1025, 2, 1, 3}, order));
         test_cases.emplace_back(new test_argsort(GGML_TYPE_F32, {16384, 1, 1, 1}, order)); // many backends only handle up to 1024
         test_cases.emplace_back(new test_argsort(GGML_TYPE_F32, {16384, 1, 1, 1}, order)); // many backends only handle up to 1024
+        test_cases.emplace_back(new test_argsort(GGML_TYPE_F32, {2047, 2, 1, 3}, order));
+        test_cases.emplace_back(new test_argsort(GGML_TYPE_F32, {2048, 2, 1, 3}, order));
+        test_cases.emplace_back(new test_argsort(GGML_TYPE_F32, {2049, 2, 1, 3}, order));
         test_cases.emplace_back(new test_argsort(GGML_TYPE_F32, {2, 8, 8192, 1}, order)); // bailingmoe2 (group selection)
         test_cases.emplace_back(new test_argsort(GGML_TYPE_F32, {2, 8, 8192, 1}, order)); // bailingmoe2 (group selection)
     }
     }