Просмотр исходного кода

metal : optimize FA kernels (#10171)

* ggml : add ggml_flash_attn_ext_get_prec

* metal : use F16 precision in FA kernels

ggml-ci

* metal : minor clean-up

* metal : compile-guard bf16 FA kernels

ggml-ci

* build : remove obsolete compile flag [no ci]

* metal : prevent int overflows [no ci]

* cuda : disable BF16 FA

ggml-ci

* metal : fix BF16 requirement for FA kernels

ggml-ci

* make : clean-up [no ci]
Georgi Gerganov 1 год назад
Родитель
Сommit
841f27abdb

+ 3 - 0
examples/llama-bench/llama-bench.cpp

@@ -256,6 +256,9 @@ static ggml_type ggml_type_from_name(const std::string & s) {
     if (s == "f16") {
         return GGML_TYPE_F16;
     }
+    if (s == "bf16") {
+        return GGML_TYPE_BF16;
+    }
     if (s == "q8_0") {
         return GGML_TYPE_Q8_0;
     }

+ 3 - 0
ggml/include/ggml.h

@@ -1746,6 +1746,9 @@ extern "C" {
             struct ggml_tensor * a,
             enum ggml_prec       prec);
 
+    GGML_API enum ggml_prec ggml_flash_attn_ext_get_prec(
+            const struct ggml_tensor * a);
+
     // TODO: needs to be adapted to ggml_flash_attn_ext
     GGML_API struct ggml_tensor * ggml_flash_attn_back(
            struct ggml_context * ctx,

+ 3 - 0
ggml/src/ggml-cuda.cu

@@ -3159,6 +3159,9 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
 #ifndef FLASH_ATTN_AVAILABLE
             return false;
 #endif
+            if (op->src[1]->type == GGML_TYPE_BF16 || op->src[2]->type == GGML_TYPE_BF16) {
+                return false;
+            }
             if (op->src[0]->ne[0] ==  64 && op->src[1]->type == GGML_TYPE_F16) {
                 return true;
             }

+ 5 - 5
ggml/src/ggml-cuda/fattn.cu

@@ -13,9 +13,9 @@ static void ggml_cuda_flash_attn_ext_wmma_f16(ggml_backend_cuda_context & ctx, g
     const ggml_tensor * KQV = dst;
     const ggml_tensor * Q   = dst->src[0];
 
-    const int32_t precision = KQV->op_params[3];
+    const enum ggml_prec prec = ggml_flash_attn_ext_get_prec(KQV);
 
-    if (precision != GGML_PREC_DEFAULT) {
+    if (prec != GGML_PREC_DEFAULT) {
         if (Q->ne[1] <= 32 || Q->ne[0] > 128) {
             constexpr int cols_per_block = 16;
             switch (Q->ne[0]) {
@@ -301,11 +301,11 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst
 
     ggml_cuda_set_device(ctx.device);
     const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc;
-    const int32_t precision = KQV->op_params[3];
+    const enum ggml_prec prec = ggml_flash_attn_ext_get_prec(KQV);
 
     // On AMD the tile kernels perform poorly, use the vec kernel instead:
     if (cc >= CC_OFFSET_AMD) {
-        if (precision == GGML_PREC_DEFAULT && fast_fp16_available(cc)) {
+        if (prec == GGML_PREC_DEFAULT && fast_fp16_available(cc)) {
             ggml_cuda_flash_attn_ext_vec_f16(ctx, dst);
         } else {
             ggml_cuda_flash_attn_ext_vec_f32(ctx, dst);
@@ -332,7 +332,7 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst
     }
 
     if (Q->ne[1] == 1 && Q->ne[0] % (2*WARP_SIZE) == 0) {
-        if (precision == GGML_PREC_DEFAULT) {
+        if (prec == GGML_PREC_DEFAULT) {
             ggml_cuda_flash_attn_ext_vec_f16(ctx, dst);
             return;
         } else if(Q->ne[0] <= 128) {

+ 56 - 18
ggml/src/ggml-metal.m

@@ -269,6 +269,12 @@ enum ggml_metal_kernel_type {
     GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H112,
     GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H128,
     GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256,
+    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H64,
+    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H80,
+    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H96,
+    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H112,
+    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H128,
+    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H256,
     GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H64,
     GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H80,
     GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H96,
@@ -300,12 +306,14 @@ enum ggml_metal_kernel_type {
     GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H128,
     GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H256,
     GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H128,
+    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_H128,
     GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H128,
     GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_H128,
     GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H128,
     GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_H128,
     GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H128,
     GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256,
+    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_H256,
     GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H256,
     GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_H256,
     GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H256,
@@ -585,6 +593,9 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
             struct ggml_metal_kernel * kernel = &ctx->kernels[e]; \
             id<MTLFunction> metal_function = [metal_library newFunctionWithName:@"kernel_"#name]; \
             kernel->pipeline = [device newComputePipelineStateWithFunction:metal_function error:&error]; \
+            GGML_LOG_INFO("%s: loaded %-40s %16p | th_max = %4d | th_width = %4d\n", __func__, "kernel_"#name, (void *) kernel->pipeline, \
+                    (int) kernel->pipeline.maxTotalThreadsPerThreadgroup, \
+                    (int) kernel->pipeline.threadExecutionWidth); \
             [metal_function release]; \
             if (error) { \
                 GGML_LOG_ERROR("%s: error: load pipeline error: %s\n", __func__, [[error description] UTF8String]); \
@@ -777,6 +788,12 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H112,       flash_attn_ext_f16_h112,        has_simdgroup_mm);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H128,       flash_attn_ext_f16_h128,        has_simdgroup_mm);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256,       flash_attn_ext_f16_h256,        has_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H64,       flash_attn_ext_bf16_h64,        has_simdgroup_mm && has_bfloat);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H80,       flash_attn_ext_bf16_h80,        has_simdgroup_mm && has_bfloat);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H96,       flash_attn_ext_bf16_h96,        has_simdgroup_mm && has_bfloat);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H112,      flash_attn_ext_bf16_h112,       has_simdgroup_mm && has_bfloat);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H128,      flash_attn_ext_bf16_h128,       has_simdgroup_mm && has_bfloat);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H256,      flash_attn_ext_bf16_h256,       has_simdgroup_mm && has_bfloat);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H64,       flash_attn_ext_q4_0_h64,        has_simdgroup_mm);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H80,       flash_attn_ext_q4_0_h80,        has_simdgroup_mm);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H96,       flash_attn_ext_q4_0_h96,        has_simdgroup_mm);
@@ -808,12 +825,14 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H128,      flash_attn_ext_q8_0_h128,       has_simdgroup_mm);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H256,      flash_attn_ext_q8_0_h256,       has_simdgroup_mm);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H128,   flash_attn_ext_vec_f16_h128,    has_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_H128,  flash_attn_ext_vec_bf16_h128,   has_simdgroup_reduction && has_bfloat);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H128,  flash_attn_ext_vec_q4_0_h128,   has_simdgroup_reduction);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_H128,  flash_attn_ext_vec_q4_1_h128,   has_simdgroup_reduction);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H128,  flash_attn_ext_vec_q5_0_h128,   has_simdgroup_reduction);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_H128,  flash_attn_ext_vec_q5_1_h128,   has_simdgroup_reduction);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H128,  flash_attn_ext_vec_q8_0_h128,   has_simdgroup_reduction);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256,   flash_attn_ext_vec_f16_h256,    has_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_H256,  flash_attn_ext_vec_bf16_h256,   has_simdgroup_reduction && has_bfloat);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H256,  flash_attn_ext_vec_q4_0_h256,   has_simdgroup_reduction);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_H256,  flash_attn_ext_vec_q4_1_h256,   has_simdgroup_reduction);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H256,  flash_attn_ext_vec_q5_0_h256,   has_simdgroup_reduction);
@@ -1111,7 +1130,7 @@ static void ggml_metal_encode_node(
     const uint64_t nb20 = src2 ? src2->nb[0] : 0; GGML_UNUSED(nb20);
     const uint64_t nb21 = src2 ? src2->nb[1] : 0;
     const uint64_t nb22 = src2 ? src2->nb[2] : 0;
-    const uint64_t nb23 = src2 ? src2->nb[3] : 0;
+    const uint64_t nb23 = src2 ? src2->nb[3] : 0; GGML_UNUSED(nb23);
 
     const int64_t  ne0  =  dst ?  dst->ne[0] : 0;
     const int64_t  ne1  =  dst ?  dst->ne[1] : 0;
@@ -3033,6 +3052,23 @@ static void ggml_metal_encode_node(
                                               }
                                 }
                             } break;
+                        case GGML_TYPE_BF16:
+                            {
+                                switch (ne00) {
+                                    case 64:  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H64 ].pipeline; break;
+                                    case 80:  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H80 ].pipeline; break;
+                                    case 96:  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H96 ].pipeline; break;
+                                    case 112: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H112].pipeline; break;
+                                    case 128: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H128].pipeline; break;
+                                    case 256: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H256].pipeline; break;
+                                    default:
+                                              {
+                                                  GGML_LOG_ERROR("unsupported size: %lld\n", ne00);
+                                                  GGML_LOG_ERROR("add template specialization for this size\n");
+                                                  GGML_ABORT("add template specialization for this size");
+                                              }
+                                }
+                            } break;
                         case GGML_TYPE_Q4_0:
                             {
                                 switch (ne00) {
@@ -3133,6 +3169,7 @@ static void ggml_metal_encode_node(
                             {
                                 switch (src1->type) {
                                     case GGML_TYPE_F16:  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H128].pipeline; break;
+                                    case GGML_TYPE_BF16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_H128].pipeline; break;
                                     case GGML_TYPE_Q4_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H128].pipeline; break;
                                     case GGML_TYPE_Q4_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_H128].pipeline; break;
                                     case GGML_TYPE_Q5_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H128].pipeline; break;
@@ -3150,6 +3187,7 @@ static void ggml_metal_encode_node(
                             {
                                 switch (src1->type) {
                                     case GGML_TYPE_F16:  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256].pipeline; break;
+                                    case GGML_TYPE_BF16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_H256].pipeline; break;
                                     case GGML_TYPE_Q4_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H256].pipeline; break;
                                     case GGML_TYPE_Q4_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_H256].pipeline; break;
                                     case GGML_TYPE_Q5_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H256].pipeline; break;
@@ -3194,18 +3232,15 @@ static void ggml_metal_encode_node(
                 [encoder setBytes:&nb11          length:sizeof(uint64_t)      atIndex:14];
                 [encoder setBytes:&nb12          length:sizeof(uint64_t)      atIndex:15];
                 [encoder setBytes:&nb13          length:sizeof(uint64_t)      atIndex:16];
-                [encoder setBytes:&nb21          length:sizeof(uint64_t)      atIndex:17];
-                [encoder setBytes:&nb22          length:sizeof(uint64_t)      atIndex:18];
-                [encoder setBytes:&nb23          length:sizeof(uint64_t)      atIndex:19];
-                [encoder setBytes:&nb31          length:sizeof(uint64_t)      atIndex:20];
-                [encoder setBytes:&ne1           length:sizeof( int64_t)      atIndex:21];
-                [encoder setBytes:&ne2           length:sizeof( int64_t)      atIndex:22];
-                [encoder setBytes:&scale         length:sizeof(   float)      atIndex:23];
-                [encoder setBytes:&max_bias      length:sizeof(   float)      atIndex:24];
-                [encoder setBytes:&m0            length:sizeof(m0)            atIndex:25];
-                [encoder setBytes:&m1            length:sizeof(m1)            atIndex:26];
-                [encoder setBytes:&n_head_log2   length:sizeof(n_head_log2)   atIndex:27];
-                [encoder setBytes:&logit_softcap length:sizeof(logit_softcap) atIndex:28];
+                [encoder setBytes:&nb31          length:sizeof(uint64_t)      atIndex:17];
+                [encoder setBytes:&ne1           length:sizeof( int64_t)      atIndex:18];
+                [encoder setBytes:&ne2           length:sizeof( int64_t)      atIndex:19];
+                [encoder setBytes:&scale         length:sizeof(   float)      atIndex:20];
+                [encoder setBytes:&max_bias      length:sizeof(   float)      atIndex:21];
+                [encoder setBytes:&m0            length:sizeof(m0)            atIndex:22];
+                [encoder setBytes:&m1            length:sizeof(m1)            atIndex:23];
+                [encoder setBytes:&n_head_log2   length:sizeof(n_head_log2)   atIndex:24];
+                [encoder setBytes:&logit_softcap length:sizeof(logit_softcap) atIndex:25];
 
                 if (!use_vec_kernel) {
                     // half8x8 kernel
@@ -3216,11 +3251,14 @@ static void ggml_metal_encode_node(
                     GGML_ASSERT(nqptg  % 8  == 0);
                     GGML_ASSERT(ncpsg  % 32 == 0);
 
+                    // 2*(2*ncpsg + nqptg)*(nsg)
+                    // ncpsg soft_max values + ncpsg mask values + a diagonal scaling matrix (in float)
+                    //
                     // 16*32*(nsg)
                     // the shared memory needed for the simdgroups to load the KV cache
                     // each thread loads (dequantizes) 16 head elements, there are 32 threads in th SG
                     //
-#define FATTN_SMEM(nsg) (GGML_PAD((nqptg*(ne00 + 2*(ncpsg + nqptg)*(nsg)) + 16*32*(nsg))*(sizeof(float)/2), 16))
+#define FATTN_SMEM(nsg) (GGML_PAD((nqptg*(ne00 + 2*(2*ncpsg + nqptg)*(nsg)) + 16*32*(nsg))*(sizeof(float)/2), 16))
 
                     int64_t nsgmax = 2;
 
@@ -3254,12 +3292,12 @@ static void ggml_metal_encode_node(
 
                     // ne00 + 2*ncpsg*(nsg)
                     // for each query, we load it as f16 in shared memory (ne00)
-                    // and store the attention scores (nqptg x ncpsg) as f32
+                    // and store the soft_max values and the mask
                     //
-                    // 2*ne00*(nsg)
-                    // each simdgroup has a full f32 head vector in shared mem to accumulate results
+                    // ne00*(nsg)
+                    // each simdgroup has a full f16 head vector in shared mem to accumulate results
                     //
-#define FATTN_SMEM(nsg) (GGML_PAD((nqptg*(ne00 + 2*ncpsg*(nsg)) + 2*ne00*(nsg))*(sizeof(float)/2), 16))
+#define FATTN_SMEM(nsg) (GGML_PAD((nqptg*(ne00 + 2*ncpsg*(nsg)) + ne00*(nsg))*(sizeof(float)/2), 16))
 
                     int64_t nsgmax = 2;
 

Разница между файлами не показана из-за своего большого размера
+ 360 - 278
ggml/src/ggml-metal.metal


+ 9 - 0
ggml/src/ggml.c

@@ -4228,6 +4228,15 @@ void ggml_flash_attn_ext_set_prec(
     ggml_set_op_params_i32(a, 3, prec_i32); // scale is on first pos, max_bias on second
 }
 
+enum ggml_prec ggml_flash_attn_ext_get_prec(
+        const struct ggml_tensor * a) {
+    GGML_ASSERT(a->op == GGML_OP_FLASH_ATTN_EXT);
+
+    const int32_t prec_i32 = ggml_get_op_params_i32(a, 3);
+
+    return (enum ggml_prec) prec_i32;
+}
+
 // ggml_flash_attn_back
 
 struct ggml_tensor * ggml_flash_attn_back(

+ 1 - 1
tests/test-backend-ops.cpp

@@ -3745,7 +3745,7 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
                     for (int nh : { 32, }) {
                         for (int kv : { 512, 1024, }) {
                             for (int nb : { 1, 3, 32, 35, }) {
-                                for (ggml_type type_KV : {GGML_TYPE_F16, GGML_TYPE_Q8_0, GGML_TYPE_Q4_0}) {
+                                for (ggml_type type_KV : {GGML_TYPE_F16, GGML_TYPE_BF16, GGML_TYPE_Q8_0, GGML_TYPE_Q4_0}) {
                                     test_cases.emplace_back(new test_flash_attn_ext(hs, nh, kv, nb, mask, max_bias, logit_softcap, type_KV));
                                 }
                             }

Некоторые файлы не были показаны из-за большого количества измененных файлов