Browse Source

CUDA: add bf16 and i32 to getrows (#14529)

Aman Gupta 6 months ago
parent
commit
b9c3eefde1
2 changed files with 10 additions and 0 deletions
  1. 8 0
      ggml/src/ggml-cuda/getrows.cu
  2. 2 0
      ggml/src/ggml-cuda/ggml-cuda.cu

+ 8 - 0
ggml/src/ggml-cuda/getrows.cu

@@ -168,6 +168,10 @@ static void ggml_cuda_get_rows_switch_src0_type(
             get_rows_cuda_float((const float *) src0_d, src1_d, dst_d,
             get_rows_cuda_float((const float *) src0_d, src1_d, dst_d,
                 ne00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb1, nb2, nb3, stream);
                 ne00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb1, nb2, nb3, stream);
             break;
             break;
+        case GGML_TYPE_I32:
+            get_rows_cuda_float((const int32_t *) src0_d, src1_d, dst_d,
+                ne00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb1, nb2, nb3, stream);
+            break;
         case GGML_TYPE_BF16:
         case GGML_TYPE_BF16:
             get_rows_cuda_float((const nv_bfloat16 *) src0_d, src1_d, dst_d,
             get_rows_cuda_float((const nv_bfloat16 *) src0_d, src1_d, dst_d,
                 ne00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb1, nb2, nb3, stream);
                 ne00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb1, nb2, nb3, stream);
@@ -210,6 +214,10 @@ void get_rows_cuda(
             ggml_cuda_get_rows_switch_src0_type(src0_d, src0_type, src1_d, (float *) dst_d,
             ggml_cuda_get_rows_switch_src0_type(src0_d, src0_type, src1_d, (float *) dst_d,
                 ne00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb1, nb2, nb3, stream);
                 ne00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb1, nb2, nb3, stream);
             break;
             break;
+        case GGML_TYPE_I32:
+            ggml_cuda_get_rows_switch_src0_type(src0_d, src0_type, src1_d, (int32_t *) dst_d,
+                ne00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb1, nb2, nb3, stream);
+            break;
         case GGML_TYPE_F16:
         case GGML_TYPE_F16:
             ggml_cuda_get_rows_switch_src0_type(src0_d, src0_type, src1_d, (half *) dst_d,
             ggml_cuda_get_rows_switch_src0_type(src0_d, src0_type, src1_d, (half *) dst_d,
                 ne00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb1, nb2, nb3, stream);
                 ne00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb1, nb2, nb3, stream);

+ 2 - 0
ggml/src/ggml-cuda/ggml-cuda.cu

@@ -3200,6 +3200,8 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
                 switch (op->src[0]->type) {
                 switch (op->src[0]->type) {
                     case GGML_TYPE_F16:
                     case GGML_TYPE_F16:
                     case GGML_TYPE_F32:
                     case GGML_TYPE_F32:
+                    case GGML_TYPE_BF16:
+                    case GGML_TYPE_I32:
                     case GGML_TYPE_Q4_0:
                     case GGML_TYPE_Q4_0:
                     case GGML_TYPE_Q4_1:
                     case GGML_TYPE_Q4_1:
                     case GGML_TYPE_Q5_0:
                     case GGML_TYPE_Q5_0: