Просмотр исходного кода

cuda : support Falcon-H1 state size for SSM_SCAN (#14602)

compilade 6 месяцев назад
Родитель
Сommit
a57d1bcb3c
3 измененных файлов с 16 добавлено и 4 удалено
  1. 2 2
      ggml/src/ggml-cuda/ggml-cuda.cu
  2. 13 2
      ggml/src/ggml-cuda/ssm-scan.cu
  3. 1 0
      tests/test-backend-ops.cpp

+ 2 - 2
ggml/src/ggml-cuda/ggml-cuda.cu

@@ -3335,8 +3335,8 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
         case GGML_OP_SSM_SCAN: {
         case GGML_OP_SSM_SCAN: {
             if (op->src[3]->ne[0] == 1) {
             if (op->src[3]->ne[0] == 1) {
                 // Mamba2
                 // Mamba2
-                // (kernel only supports d_state == 128 && d_head % 16 == 0)
-                return op->src[0]->ne[0] == 128 && op->src[0]->ne[1] % 16 == 0;
+                // (kernel only supports (d_state == 128 || d_state == 256) && d_head % 16 == 0)
+                return (op->src[0]->ne[0] == 128 || op->src[0]->ne[0] == 256) && op->src[0]->ne[1] % 16 == 0;
             } else {
             } else {
                 // Mamba
                 // Mamba
                 // (kernel only supports d_state == 16, d_head == 1, n_head % 128 == 0, n_group == 1)
                 // (kernel only supports d_state == 16, d_head == 1, n_head % 128 == 0, n_group == 1)

+ 13 - 2
ggml/src/ggml-cuda/ssm-scan.cu

@@ -201,11 +201,11 @@ static void ssm_scan_f32_cuda(const float * src0, const float * src1, const floa
                               const int src5_nb3, const int64_t s_off, const int64_t d_state, const int64_t head_dim,
                               const int src5_nb3, const int64_t s_off, const int64_t d_state, const int64_t head_dim,
                               const int64_t n_head, const int64_t n_group, const int64_t n_tok, const int64_t n_seq,
                               const int64_t n_head, const int64_t n_group, const int64_t n_tok, const int64_t n_seq,
                               cudaStream_t stream) {
                               cudaStream_t stream) {
-    const int threads = 128;
     // NOTE: if you change conditions here, be sure to update the corresponding supports_op condition!
     // NOTE: if you change conditions here, be sure to update the corresponding supports_op condition!
     if (src3_nb1 == sizeof(float)) {
     if (src3_nb1 == sizeof(float)) {
         // Mamba-2
         // Mamba-2
         if (d_state == 128) {
         if (d_state == 128) {
+            const int threads = 128;
             GGML_ASSERT(d_state % threads == 0);
             GGML_ASSERT(d_state % threads == 0);
             // NOTE: can be any power of two between 4 and 64
             // NOTE: can be any power of two between 4 and 64
             const int splitH = 16;
             const int splitH = 16;
@@ -215,10 +215,21 @@ static void ssm_scan_f32_cuda(const float * src0, const float * src1, const floa
                     src0, src1, src2, src3, src4, src5, src6, dst,
                     src0, src1, src2, src3, src4, src5, src6, dst,
                     src0_nb2, src0_nb3, src1_nb2, src1_nb3, src2_nb1, src2_nb2, src3_nb1,
                     src0_nb2, src0_nb3, src1_nb2, src1_nb3, src2_nb1, src2_nb2, src3_nb1,
                     src4_nb2, src4_nb3, src5_nb2, src5_nb3, s_off, n_head, head_dim, n_group, n_tok);
                     src4_nb2, src4_nb3, src5_nb2, src5_nb3, s_off, n_head, head_dim, n_group, n_tok);
+        } else if (d_state == 256) { // Falcon-H1
+            const int threads = 256;
+            // NOTE: can be any power of two between 8 and 64
+            const int splitH = 16;
+            GGML_ASSERT(head_dim % splitH == 0);
+            const dim3 blocks((n_head * head_dim + (splitH - 1)) / splitH, n_seq, 1);
+            ssm_scan_f32_group<16, 256><<<blocks, threads, 0, stream>>>(
+                    src0, src1, src2, src3, src4, src5, src6, dst,
+                    src0_nb2, src0_nb3, src1_nb2, src1_nb3, src2_nb1, src2_nb2, src3_nb1,
+                    src4_nb2, src4_nb3, src5_nb2, src5_nb3, s_off, n_head, head_dim, n_group, n_tok);
         } else {
         } else {
-            GGML_ABORT("doesn't support d_state!=128.");
+            GGML_ABORT("doesn't support d_state!=(128 or 256).");
         }
         }
     } else {
     } else {
+        const int threads = 128;
         // Mamba-1
         // Mamba-1
         GGML_ASSERT(n_head % threads == 0);
         GGML_ASSERT(n_head % threads == 0);
         GGML_ASSERT(head_dim == 1);
         GGML_ASSERT(head_dim == 1);

+ 1 - 0
tests/test-backend-ops.cpp

@@ -5069,6 +5069,7 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
 
 
     test_cases.emplace_back(new test_ssm_scan(GGML_TYPE_F32, 16, 1, 1024, 1, 32, 4)); // Mamba-1
     test_cases.emplace_back(new test_ssm_scan(GGML_TYPE_F32, 16, 1, 1024, 1, 32, 4)); // Mamba-1
     test_cases.emplace_back(new test_ssm_scan(GGML_TYPE_F32, 128, 64, 16, 2, 32, 4)); // Mamba-2
     test_cases.emplace_back(new test_ssm_scan(GGML_TYPE_F32, 128, 64, 16, 2, 32, 4)); // Mamba-2
+    test_cases.emplace_back(new test_ssm_scan(GGML_TYPE_F32, 256, 64,  8, 2, 32, 4)); // Falcon-H1
 
 
     test_cases.emplace_back(new test_rwkv_wkv6(GGML_TYPE_F32, 32, 64, 1, 1));
     test_cases.emplace_back(new test_rwkv_wkv6(GGML_TYPE_F32, 32, 64, 1, 1));
     test_cases.emplace_back(new test_rwkv_wkv6(GGML_TYPE_F32, 32, 64, 32, 1));
     test_cases.emplace_back(new test_rwkv_wkv6(GGML_TYPE_F32, 32, 64, 32, 1));