CUDA: enable Gemma FA for HIP/Pascal (#9581)

Johannes Gäßler · 1 year ago
commit a5b57b08ce

3 changed files with 10 additions and 10 deletions
  1. ggml/src/ggml-cuda.cu         +8 -8
  2. ggml/src/ggml-cuda/fattn.cu   +1 -1
  3. tests/test-backend-ops.cpp    +1 -1

ggml/src/ggml-cuda.cu  (+8 -8)

@@ -2976,19 +2976,19 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, cons
         case GGML_OP_LEAKY_RELU:
         case GGML_OP_RWKV_WKV:
             return true;
-        case GGML_OP_FLASH_ATTN_EXT:
-#if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
-            return (op->src[0]->ne[0] == 64 && op->src[1]->type == GGML_TYPE_F16) || op->src[0]->ne[0] == 128;
-#else
+        case GGML_OP_FLASH_ATTN_EXT: {
+            if (op->src[0]->ne[0] ==  64 && op->src[1]->type == GGML_TYPE_F16) {
+                return true;
+            }
             if (op->src[0]->ne[0] == 128) {
                 return true;
             }
-            if (op->src[0]->ne[0] ==  64 && op->src[1]->type == GGML_TYPE_F16) {
+            if (op->src[0]->ne[0] == 256 && op->src[1]->type == GGML_TYPE_F16 && op->src[2]->type == GGML_TYPE_F16) {
                 return true;
             }
-            return ggml_cuda_info().devices[cuda_ctx->device].cc >= CC_VOLTA &&
-                op->src[1]->type == GGML_TYPE_F16 && op->src[2]->type == GGML_TYPE_F16;
-#endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
+            const int cc = ggml_cuda_info().devices[cuda_ctx->device].cc;
+            return cc >= CC_VOLTA && cc < CC_OFFSET_AMD && op->src[1]->type == GGML_TYPE_F16 && op->src[2]->type == GGML_TYPE_F16;
+        }
         case GGML_OP_CROSS_ENTROPY_LOSS:
         case GGML_OP_CROSS_ENTROPY_LOSS_BACK:
         case GGML_OP_OPT_STEP_ADAMW:
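
For context, the reworked FLASH_ATTN_EXT branch can be read as a single standalone predicate. The sketch below is only a hedged restatement of the logic in the hunk above: the helper name flash_attn_ext_supported and its flattened parameters are illustrative, not part of the patch, and it assumes ggml.h plus the ggml-cuda internal common.cuh header for the CC_VOLTA and CC_OFFSET_AMD constants.

    #include "ggml.h"        // ggml_type, GGML_TYPE_F16
    #include "common.cuh"    // CC_VOLTA, CC_OFFSET_AMD (ggml-cuda internal header)

    // Standalone restatement of the new GGML_OP_FLASH_ATTN_EXT support check.
    // head_size = op->src[0]->ne[0], k_type/v_type = types of op->src[1]/op->src[2],
    // cc = compute capability of the active device.
    static bool flash_attn_ext_supported(int64_t head_size, ggml_type k_type, ggml_type v_type, int cc) {
        if (head_size == 64 && k_type == GGML_TYPE_F16) {
            return true;
        }
        if (head_size == 128) {
            return true; // any K/V type
        }
        if (head_size == 256 && k_type == GGML_TYPE_F16 && v_type == GGML_TYPE_F16) {
            return true; // head size used by Gemma, now allowed on HIP and pre-Volta devices
        }
        // remaining cases require Volta or newer NVIDIA hardware and F16 K/V;
        // cc < CC_OFFSET_AMD excludes HIP devices, whose cc values are offset by that constant
        return cc >= CC_VOLTA && cc < CC_OFFSET_AMD && k_type == GGML_TYPE_F16 && v_type == GGML_TYPE_F16;
    }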

ggml/src/ggml-cuda/fattn.cu  (+1 -1)

@@ -314,7 +314,7 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst
     }
 
     if (!fast_fp16_available(cc)) {
-        if (Q->ne[1] <= 8) {
+        if (Q->ne[1] <= 8 || Q->ne[0] == 256) {
             ggml_cuda_flash_attn_ext_vec_f32(ctx, dst);
         } else {
             ggml_cuda_flash_attn_ext_tile_f32(ctx, dst);
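
The fattn.cu change only affects the FP32 fallback taken when fast FP16 is unavailable (e.g. on Pascal): head size 256 now always dispatches to the F32 vector kernel. A minimal sketch of that decision, using a hypothetical helper dispatch_fattn_f32 (the real code calls ggml_cuda_flash_attn_ext_vec_f32 or ggml_cuda_flash_attn_ext_tile_f32 directly):

    #include <cstdint>

    // Which F32 flash-attention kernel the fallback path picks when fast FP16 is
    // unavailable; batch = Q->ne[1], head_size = Q->ne[0]. Names are illustrative.
    enum class fattn_f32_kernel { vec, tile };

    static fattn_f32_kernel dispatch_fattn_f32(int64_t batch, int64_t head_size) {
        // small batches and, after this patch, head size 256 use the vector kernel;
        // everything else falls back to the tile kernel
        if (batch <= 8 || head_size == 256) {
            return fattn_f32_kernel::vec;
        }
        return fattn_f32_kernel::tile;
    }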

tests/test-backend-ops.cpp  (+1 -1)

@@ -3599,7 +3599,7 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op
                     if (hs != 128 && logit_softcap != 0.0f) continue;
                     for (int nh : { 32, }) {
                         for (int kv : { 512, 1024, }) {
-                            for (int nb : { 1, 2, 4, 8, }) {
+                            for (int nb : { 1, 3, 32, 35, }) {
                                 for (ggml_type type_KV : {GGML_TYPE_F16, GGML_TYPE_Q8_0, GGML_TYPE_Q4_0}) {
                                     test_cases.emplace_back(new test_flash_attn_ext(hs, nh, kv, nb, mask, max_bias, logit_softcap, type_KV));
                                 }