Jelajahi Sumber

cuda : fix supports_op condition for get_rows when number of blocks is too large (#15868)

* cuda : fix supports_op condition for get_rows when src1->ne2 > 1

ggml-ci

* ggml : add comment about ggml_get_rows

ggml-ci

* cuda : add FIXME [no ci]

* cuda : update support condition

ggml-ci
Georgi Gerganov 4 bulan lalu
induk
melakukan
b0d52998b9
3 mengubah file dengan 10 tambahan dan 1 penghapusan
  1. 5 1
      ggml/include/ggml.h
  2. 4 0
      ggml/src/ggml-cuda/ggml-cuda.cu
  3. 1 0
      ggml/src/ggml.c

+ 5 - 1
ggml/include/ggml.h

@@ -1529,7 +1529,11 @@ extern "C" {
             struct ggml_context * ctx,
             struct ggml_tensor  * a);
 
-    // supports 3D: a->ne[2] == b->ne[1]
+    // supports 4D a:
+    // a     [n_embd, ne1, ne2, ne3]
+    // b I32 [n_rows, ne2, ne3, 1]
+    //
+    // return [n_embd, n_rows, ne2, ne3]
     GGML_API struct ggml_tensor * ggml_get_rows(
             struct ggml_context * ctx,
             struct ggml_tensor  * a,  // data

+ 4 - 0
ggml/src/ggml-cuda/ggml-cuda.cu

@@ -3392,6 +3392,10 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
             return op->type == GGML_TYPE_F32 && op->src[0]->type == GGML_TYPE_F32 && op->src[1]->type == GGML_TYPE_F32;
         case GGML_OP_GET_ROWS:
             {
+                // FIXME: https://github.com/ggml-org/llama.cpp/pull/15868
+                if (op->src[1]->ne[1]*op->src[1]->ne[2] > 65535) {
+                    return false;
+                }
                 switch (op->src[0]->type) {
                     case GGML_TYPE_F16:
                     case GGML_TYPE_F32:

+ 1 - 0
ggml/src/ggml.c

@@ -3623,6 +3623,7 @@ struct ggml_tensor * ggml_get_rows(
         struct ggml_tensor  * a,
         struct ggml_tensor  * b) {
     GGML_ASSERT(a->ne[2] == b->ne[1]);
+    GGML_ASSERT(a->ne[3] == b->ne[2]);
     GGML_ASSERT(b->ne[3] == 1);
     GGML_ASSERT(b->type == GGML_TYPE_I32);