Browse Source

vulkan: handle rope with large number of rows (#18306)

Jeff Bolz 1 month ago
parent
commit
10dc500bdb

+ 14 - 3
ggml/src/ggml-vulkan/ggml-vulkan.cpp

@@ -1192,6 +1192,7 @@ struct vk_op_diag_mask_push_constants {
 struct vk_op_rope_push_constants {
     uint32_t rope_mode;
     uint32_t ncols;
+    uint32_t nrows;
     uint32_t n_dims;
     float freq_scale;
     uint32_t p_delta_rows;
@@ -9090,10 +9091,20 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
             elements = { num_groups * (uint32_t)src0->ne[3], 1, 1 };
         } break;
     case GGML_OP_DIAG_MASK_INF:
-    case GGML_OP_ROPE:
-    case GGML_OP_ROPE_BACK:
         elements = { (uint32_t)ggml_nrows(src0), (uint32_t)ne00, 1 };
         break;
+    case GGML_OP_ROPE:
+    case GGML_OP_ROPE_BACK:
+        {
+            uint32_t nrows = (uint32_t)ggml_nrows(src0);
+            uint32_t z = 1;
+            if (nrows > ctx->device->properties.limits.maxComputeWorkGroupCount[0]) {
+                z = CEIL_DIV(nrows, 32768);
+                nrows = 32768;
+            }
+            elements = { nrows, (uint32_t)ne00, z };
+
+        } break;
     case GGML_OP_GET_ROWS:
         elements = { (uint32_t)ne00, (uint32_t)ne10, (uint32_t)(ne11 * ne12) };
         elements[1] = std::min(elements[1], ctx->device->properties.limits.maxComputeWorkGroupCount[1]);
@@ -10021,7 +10032,7 @@ static vk_op_rope_push_constants ggml_vk_make_rope_constants(const ggml_tensor *
     uint32_t nb02 = src0->nb[2] / ggml_type_size(src0->type);
 
     vk_op_rope_push_constants rope {
-        (uint32_t)mode, (uint32_t)src0->ne[0], (uint32_t)n_dims, freq_scale, (uint32_t)src0->ne[1],
+        (uint32_t)mode, (uint32_t)src0->ne[0], (uint32_t)ggml_nrows(src0), (uint32_t)n_dims, freq_scale, (uint32_t)src0->ne[1],
         freq_base, ext_factor, attn_factor, {corr_dims[0], corr_dims[1]}, theta_scale,
         has_ff, (uint32_t)src0->ne[2], nb01, nb02,
         { sections[0], sections[1], sections[2], sections[3] }, is_imrope, backprop, set_rows_stride,

+ 4 - 1
ggml/src/ggml-vulkan/vulkan-shaders/rope_multi.comp

@@ -6,6 +6,9 @@
 void main() {
     const uint i0 = 2*gl_GlobalInvocationID.y;
     // i1 is actually i2*nb2+i1, but the rows are contiguous
-    const uint i1 = gl_GlobalInvocationID.x;
+    const uint i1 = gl_GlobalInvocationID.x + 32768 * gl_GlobalInvocationID.z;
+    if (i1 >= pc.nrows) {
+        return;
+    }
     rope_multi(i0, i1, pc);
 }

+ 4 - 1
ggml/src/ggml-vulkan/vulkan-shaders/rope_neox.comp

@@ -6,6 +6,9 @@
 void main() {
     const uint i0 = 2*gl_GlobalInvocationID.y;
     // i1 is actually i2*nb2+i1, but the rows are contiguous
-    const uint i1 = gl_GlobalInvocationID.x;
+    const uint i1 = gl_GlobalInvocationID.x + 32768 * gl_GlobalInvocationID.z;
+    if (i1 >= pc.nrows) {
+        return;
+    }
     rope_neox(i0, i1, pc);
 }

+ 4 - 1
ggml/src/ggml-vulkan/vulkan-shaders/rope_norm.comp

@@ -6,6 +6,9 @@
 void main() {
     const uint i0 = 2*gl_GlobalInvocationID.y;
     // i1 is actually i2*nb2+i1, but the rows are contiguous
-    const uint i1 = gl_GlobalInvocationID.x;
+    const uint i1 = gl_GlobalInvocationID.x + 32768 * gl_GlobalInvocationID.z;
+    if (i1 >= pc.nrows) {
+        return;
+    }
     rope_norm(i0, i1, pc);
 }

+ 1 - 0
ggml/src/ggml-vulkan/vulkan-shaders/rope_params.glsl

@@ -6,6 +6,7 @@
 struct rope_params {
     uint rope_mode;
     uint ncols;
+    uint nrows;
     uint n_dims;
     float freq_scale;
     uint p_delta_rows;

+ 4 - 1
ggml/src/ggml-vulkan/vulkan-shaders/rope_vision.comp

@@ -6,6 +6,9 @@
 void main() {
     const uint i0 = 2*gl_GlobalInvocationID.y;
     // i1 is actually i2*nb2+i1, but the rows are contiguous
-    const uint i1 = gl_GlobalInvocationID.x;
+    const uint i1 = gl_GlobalInvocationID.x + 32768 * gl_GlobalInvocationID.z;
+    if (i1 >= pc.nrows) {
+        return;
+    }
     rope_vision(i0, i1, pc);
 }

+ 3 - 0
tests/test-backend-ops.cpp

@@ -7775,6 +7775,7 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
                                     test_cases.emplace_back(new test_rope(type, {128,  40, 2, 1}, 128, GGML_ROPE_TYPE_NORMAL, 512, fs, ef, af, ff, v, fw)); // llama 13B
                                     test_cases.emplace_back(new test_rope(type, {128,  52, 2, 1}, 128, GGML_ROPE_TYPE_NORMAL, 512, fs, ef, af, ff, v, fw)); // llama 30B
                                     test_cases.emplace_back(new test_rope(type, {128,  64, 2, 1}, 128, GGML_ROPE_TYPE_NORMAL, 512, fs, ef, af, ff, v, fw)); // llama 65B
+                                    test_cases.emplace_back(new test_rope(type, {16, 16, 8192, 1}, 16, GGML_ROPE_TYPE_NORMAL, 512, fs, ef, af, ff, v, fw));
                                 }
 
                                 if (all) {
@@ -7789,6 +7790,7 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
                                     test_cases.emplace_back(new test_rope(type, { 80,  32, 2, 1},  20, GGML_ROPE_TYPE_NEOX, 512, fs, ef, af, ff, v, fw)); // neox (stablelm)
                                     test_cases.emplace_back(new test_rope(type, { 80,  32, 2, 1},  32, GGML_ROPE_TYPE_NEOX, 512, fs, ef, af, ff, v, fw)); // neox (phi-2)
                                     test_cases.emplace_back(new test_rope(type, { 80,  32, 4, 1},  32, GGML_ROPE_TYPE_NEOX, 512, fs, ef, af, ff, v, fw)); // neox (phi-2)
+                                    test_cases.emplace_back(new test_rope(type, { 16, 16, 8192, 1},  16, GGML_ROPE_TYPE_NEOX, 512, fs, ef, af, ff, v, fw));
                                 }
 
                                 if (all) {
@@ -7802,6 +7804,7 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
                                     test_cases.emplace_back(new test_rope(type, {128,  28, 2, 1},  32, GGML_ROPE_TYPE_IMROPE,  512, fs, ef, af, ff, v, fw));
                                     test_cases.emplace_back(new test_rope(type, { 80,  16, 2, 1},  80, GGML_ROPE_TYPE_VISION, 512, fs, ef, af, ff, v, fw)); // rope_multi,m-rope (qwen2vl ViT)
                                     test_cases.emplace_back(new test_rope(type, {128,  16, 2, 1}, 128, GGML_ROPE_TYPE_IMROPE, 512, fs, ef, af, ff, v, fw)); // rope_multi,m-rope (qwen3vl)
+                                    test_cases.emplace_back(new test_rope(type, {16, 16, 8192, 1}, 16, GGML_ROPE_TYPE_IMROPE, 512, fs, ef, af, ff, v, fw));
                                 }
 
                                 test_cases.emplace_back(new test_rope(type, { 64, 128, 2, 1},  64, GGML_ROPE_TYPE_NEOX, 512, fs, ef, af, ff, v, fw)); // neox (falcon 40B)